Example #1
def test():
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)

    model = cnn_MNIST()
    checkpoint = torch.load("../examples/vision/pretrain/mnist_cnn_small.pth",
                            map_location="cpu")
    model.load_state_dict(checkpoint)

    N = 2
    n_classes = 10
    image = torch.randn(N, 1, 28, 28)
    image = image.to(torch.float32) / 255.0

    model = BoundedModule(model, torch.empty_like(image), device="cpu")
    eps = 0.3
    norm = np.inf
    ptb = PerturbationLpNorm(norm=norm, eps=eps)
    image = BoundedTensor(image, ptb)
    pred = model(image)
    lb, ub = model.compute_bounds()
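    # lb and ub are elementwise lower/upper bounds on the 10 output logits, valid for any
    # perturbed input inside the eps-ball attached to `image` above.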

    assert lb.shape == ub.shape == torch.Size((2, 10))

    path = 'data/constant_test_data'
    if args.gen_ref:
        torch.save((lb, ub), path)
    else:
        lb_ref, ub_ref = torch.load(path)
        print(lb)
        print(lb_ref)
        assert torch.allclose(lb, lb_ref)
        assert torch.allclose(ub, ub_ref)
Example #2
    def test(self):
        model_oris = [
            models.model_resnet(width=1, mult=2),
            models.ResNet18(in_planes=2)
        ]
        self.result = []

        for model_ori in model_oris:
            conv_mode = 'patches'  # conv_mode can be set as 'matrix' or 'patches'
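            # 'patches' keeps convolution bounds in an implicit patch representation, which is
            # usually faster and more memory-efficient; 'matrix' unrolls each convolution into
            # an explicit dense matrix.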

            normalize = torchvision.transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
            test_data = torchvision.datasets.CIFAR10(
                "./data",
                train=False,
                download=True,
                transform=torchvision.transforms.Compose(
                    [torchvision.transforms.ToTensor(), normalize]))
            N = 1
            n_classes = 10

            image = torch.Tensor(test_data.data[:N]).reshape(N, 3, 32, 32)
            image = image.to(torch.float32) / 255.0

            model = BoundedModule(model_ori,
                                  image,
                                  bound_opts={"conv_mode": conv_mode})

            ptb = PerturbationLpNorm(norm=np.inf, eps=0.03)
            image = BoundedTensor(image, ptb)
            pred = model(image)
            lb, ub = model.compute_bounds(IBP=False, C=None, method='backward')
            self.result += [lb, ub]

        self.check()
Example #3
    def compute_and_compare_bounds(self, eps, norm, IBP, method):
        input_data = torch.randn((N, 256))
        model = BoundedModule(self.original_model,
                              torch.empty_like(input_data))
        ptb = PerturbationLpNorm(norm=norm, eps=eps)
        ptb_data = BoundedTensor(input_data, ptb)
        pred = model(ptb_data)
        label = torch.argmax(pred, dim=1).cpu().detach().numpy()
        # Compute bounds.
        lb, ub = model.compute_bounds(IBP=IBP, method=method)
        # Compute dual norm.
        if norm == 1:
            q = np.inf
        elif norm == np.inf:
            q = 1.0
        else:
            q = 1.0 / (1.0 - (1.0 / norm))
        # Compute reference manually.
        weight, bias = list(model.parameters())
        norm = weight.norm(p=q, dim=1)
        expected_pred = input_data.matmul(weight.t()) + bias
        expected_ub = eps * norm + expected_pred
        expected_lb = -eps * norm + expected_pred

        # Check equivalence.
        self.assertEqual(expected_pred, pred)
        self.assertEqual(expected_ub, ub)
        self.assertEqual(expected_lb, lb)
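
A note on the dual-norm step above: by Hölder's inequality, a perturbation delta with ||delta||_p <= eps can change a linear output w.x by at most eps * ||w||_q, where 1/p + 1/q = 1. The following is a minimal, self-contained sketch (made-up sizes, torch only) checking this for p = inf, where the worst case is attained at delta = eps * sign(w):

import torch

torch.manual_seed(0)
w = torch.randn(16)   # one output row of a linear layer
x = torch.randn(16)
eps = 0.3

# Over ||delta||_inf <= eps, the maximum of w @ (x + delta) is w @ x + eps * ||w||_1
# (dual norm q = 1), attained at delta = eps * sign(w).
worst_case = w @ (x + eps * torch.sign(w))
holder_bound = w @ x + eps * w.norm(p=1)
assert torch.allclose(worst_case, holder_bound)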
Example #4
def test():
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)

    models = [2, 3]
    paddings = [1, 2]
    strides = [1, 3]

    N = 2
    n_classes = 10
    image = torch.randn(N, 1, 28, 28)
    image = image.to(torch.float32) / 255.0

    for layer_num in models:
        for padding in paddings:
            for stride in strides:
                # print(layer_num, padding, stride)
                try:
                    model_ori = cnn_model(layer_num, padding, stride)
                except:
                    continue

                model = BoundedModule(model_ori,
                                      torch.empty_like(image),
                                      device="cpu",
                                      bound_opts={"conv_mode": "patches"})
                eps = 0.3
                norm = np.inf
                ptb = PerturbationLpNorm(norm=norm, eps=eps)
                image = BoundedTensor(image, ptb)
                pred = model(image)
                lb, ub = model.compute_bounds()

                model = BoundedModule(model_ori,
                                      torch.empty_like(image),
                                      device="cpu",
                                      bound_opts={"conv_mode": "matrix"})
                pred = model(image)
                lb_ref, ub_ref = model.compute_bounds()

                assert lb.shape == ub.shape == torch.Size((N, n_classes))
                assert torch.allclose(lb, lb_ref)
                assert torch.allclose(ub, ub_ref)
Example #5
def test():
    net = ResNet18()
    N = 2
    n_classes = 10
    x = torch.randn(N, 3, 32, 32)
    y = net(x)

    device = 'cpu'
    if device == 'cuda':
        x = x.cuda()
        y = y.cuda()

    model = BoundedModule(net,
                          torch.empty_like(x),
                          bound_opts={"conv_mode": "patches"},
                          device=device)
    print("Model structure: \n", str(net))
    eps = 0.3
    norm = np.inf
    ptb = PerturbationLpNorm(norm=norm, eps=eps)
    image = BoundedTensor(x, ptb)
    pred = model(image)
    lb, ub = model.compute_bounds()

    model = BoundedModule(net,
                          torch.empty_like(x),
                          bound_opts={"conv_mode": "matrix"},
                          device=device)
    eps = 0.3
    norm = np.inf
    ptb = PerturbationLpNorm(norm=norm, eps=eps)
    image = BoundedTensor(x, ptb)
    pred = model(image)
    lb_ref, ub_ref = model.compute_bounds()

    # assert lb.shape == ub.shape == torch.Size((N, n_classes))
    print((lb - lb_ref).sum(), (ub - ub_ref).sum())
    assert torch.allclose(lb, lb_ref)
    assert torch.allclose(ub, ub_ref)
Example #6
 def compute_and_compare_bounds(self, eps, norm, IBP, method):
     input_data = torch.randn((N, 1, input_dim, input_dim))
     model = BoundedModule(self.original_model,
                           torch.empty_like(input_data))
     ptb = PerturbationLpNorm(norm=norm, eps=eps)
     ptb_data = BoundedTensor(input_data, ptb)
     pred = model(ptb_data)
     label = torch.argmax(pred, dim=1).cpu().detach().numpy()
     # Compute bounds.
     lb, ub = model.compute_bounds(IBP=IBP, method=method)
     # Compute reference.
     conv_weight, conv_bias = list(model.parameters())
     conv_bias = conv_bias.view(1, out_channel, 1, 1)
     matrix_eye = torch.eye(input_dim * input_dim).view(
         input_dim * input_dim, 1, input_dim, input_dim)
     # Obtain equivalent weight and bias for convolution.
     weight = self.original_model.conv(
         matrix_eye
     ) - conv_bias  # Output is (batch, channel, height, width).
     weight = weight.view(
         input_dim * input_dim,
         -1)  # Dimension is (flattened_input, flattened_output).
     bias = conv_bias.repeat(1, 1, input_dim // 2, input_dim // 2).view(-1)
     flattened_data = input_data.view(N, -1)
     # Compute dual norm.
     if norm == 1:
         q = np.inf
     elif norm == np.inf:
         q = 1.0
     else:
         q = 1.0 / (1.0 - (1.0 / norm))
     # Manually compute bounds.
     norm = weight.t().norm(p=q, dim=1)
     expected_pred = flattened_data.matmul(weight) + bias
     expected_ub = eps * norm + expected_pred
     expected_lb = -eps * norm + expected_pred
     # Check equivalence.
     if method == 'backward' or method == 'forward':
         self.assertEqual(expected_pred, pred)
         self.assertEqual(expected_ub, ub)
         self.assertEqual(expected_lb, lb)
Example #7
    def test(self):
        model = cnn_MNIST()
        checkpoint = torch.load(
            "../examples/vision/pretrain/mnist_cnn_small.pth",
            map_location="cpu")
        model.load_state_dict(checkpoint)

        N = 2
        n_classes = 10
        image = torch.randn(N, 1, 28, 28)
        image = image.to(torch.float32) / 255.0

        model = BoundedModule(model, torch.empty_like(image), device="cpu")
        eps = 0.3
        norm = np.inf
        ptb = PerturbationLpNorm(norm=norm, eps=eps)
        image = BoundedTensor(image, ptb)
        pred = model(image)
        lb, ub = model.compute_bounds()

        assert lb.shape == ub.shape == torch.Size((2, 10))

        self.result = (lb, ub)
        self.check()
Example #8
                      image,
                      bound_opts={"conv_mode": conv_mode},
                      device=device)

## Step 4: Compute bounds using LiRPA given a perturbation
eps = 0.03
norm = np.inf
ptb = PerturbationLpNorm(norm=norm, eps=eps)
image = BoundedTensor(image, ptb)
# Get model prediction as usual
pred = model(image)

# Compute bounds
torch.cuda.empty_cache()
print('Using {} mode to compute convolution.'.format(conv_mode))
lb, ub = model.compute_bounds(IBP=False, C=None, method='backward')

## Step 5: Final output
# pred = pred.detach().cpu().numpy()
lb = lb.detach().cpu().numpy()
ub = ub.detach().cpu().numpy()
for i in range(N):
    # print("Image {} top-1 prediction {}".format(i, label[i]))
    for j in range(n_classes):
        print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format(
            j=j, l=lb[i][j], u=ub[i][j]))
    print()

# Print the GPU memory usage
print('Memory usage in "{}" mode:'.format(conv_mode))
print(torch.cuda.memory_summary())
Example #9
    if batch_idx < 2:

        ## Step 3: wrap model with auto_LiRPA
        # The second parameter is for constructing the trace of the computational graph, and its content is not important.

        model = BoundedModule(model, inputs, device="cuda")

        ## Step 4: Compute bounds using LiRPA given a perturbation
        eps = 0.3
        norm = np.inf
        ptb = PerturbationLpNorm(norm=norm, eps=eps)
        image = BoundedTensor(inputs, ptb)
        # Get model prediction as usual
        pred = model(image)
        label = torch.argmax(pred, dim=1).cpu().numpy()
        # Compute bounds
        lb, ub = model.compute_bounds()

        ## Step 5: Final output
        pred = pred.detach().cpu().numpy()
        lb = lb.detach().cpu().numpy()
        ub = ub.detach().cpu().numpy()
        for i in range(N):
            print("Image {} top-1 prediction {}".format(i, label[i]))
            for j in range(n_classes):
                print(
                    "f_{j}(x_0) = {fx0:8.3f},   {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}"
                    .format(j=j, fx0=pred[i][j], l=lb[i][j], u=ub[i][j]))
            print()
Example #10
class LiRPAConvNet:
    def __init__(self,
                 model_ori,
                 pred,
                 test,
                 solve_slope=False,
                 device='cuda',
                 simplify=True,
                 in_size=(1, 3, 32, 32)):
        """
        Convert a PyTorch model to an auto_LiRPA BoundedModule.
        """

        layers = list(model_ori.children())
        if simplify:
            added_prop_layers = add_single_prop(layers, pred, test)
            self.layers = added_prop_layers
        else:
            self.layers = layers
        net = nn.Sequential(*self.layers)
        self.solve_slope = solve_slope
        if solve_slope:
            self.net = BoundedModule(net,
                                     torch.rand(in_size),
                                     bound_opts={
                                         'relu': 'random_evaluation',
                                         'conv_mode': 'patches'
                                     },
                                     device=device)
        else:
            self.net = BoundedModule(net,
                                     torch.rand(in_size),
                                     bound_opts={'relu': 'same-slope'},
                                     device=device)
        self.net.eval()

    def get_lower_bound(self,
                        pre_lbs,
                        pre_ubs,
                        decision,
                        slopes=None,
                        history=[],
                        decision_thresh=0,
                        layer_set_bound=True,
                        beta=True):
        """
        # (in) pre_lbs: layers list -> tensor(batch, layer shape)
        # (in) relu_mask: relu layers list -> tensor(batch, relu layer shape (view-1))
        # (in) slope: relu layers list -> tensor(batch, relu layer shape)
        # (out) lower_bounds: batch list -> layers list -> tensor(layer shape)
        # (out) masks_ret: batch list -> relu layers list -> tensor(relu layer shape)
        # (out) slope: batch list -> relu layers list -> tensor(relu layer shape)
        """
        start = time.time()
        lower_bounds, upper_bounds, masks_ret, slopes = self.update_bounds_parallel(
            pre_lbs,
            pre_ubs,
            decision,
            slopes,
            beta=beta,
            early_stop=False,
            opt_choice="adam",
            iteration=20,
            history=history,
            decision_thresh=decision_thresh,
            layer_set_bound=layer_set_bound)

        end = time.time()
        print('batch time: ', end - start)
        return [i[-1] for i in upper_bounds
                ], [i[-1] for i in lower_bounds
                    ], None, masks_ret, lower_bounds, upper_bounds, slopes

    def get_relu(self, model, idx):
        # find the i-th ReLU layer
        i = 0
        for layer in model.children():
            if isinstance(layer, BoundRelu):
                i += 1
                if i == idx:
                    return layer

    def get_candidate(self, model, lb, ub):
        # Get the intermediate bounds in the current model and build self.name_dict,
        # which maps layer indices in self.layers to node names in the BoundedModule.

        if self.input_domain.ndim == 2:
            lower_bounds = [self.input_domain[:, 0].squeeze(-1)]
            upper_bounds = [self.input_domain[:, 1].squeeze(-1)]
        else:
            lower_bounds = [self.input_domain[:, :, :, 0].squeeze(-1)]
            upper_bounds = [self.input_domain[:, :, :, 1].squeeze(-1)]
        self.pre_relu_indices = []
        idx, i, model_i = 0, 0, 0
        # build a name_dict to map layer idx in self.layers to BoundedModule
        self.name_dict = {0: model.root_name[0]}
        model_names = list(model._modules)

        for layer in self.layers:
            if isinstance(layer, nn.ReLU):
                i += 1
                this_relu = self.get_relu(model, i)
                lower_bounds[-1] = this_relu.inputs[0].lower.squeeze().detach()
                upper_bounds[-1] = this_relu.inputs[0].upper.squeeze().detach()
                lower_bounds.append(F.relu(lower_bounds[-1]).detach())
                upper_bounds.append(F.relu(upper_bounds[-1]).detach())
                self.pre_relu_indices.append(idx)
                self.name_dict[idx + 1] = model_names[model_i]
                model_i += 1
            elif isinstance(layer, Flatten):
                lower_bounds.append(lower_bounds[-1].reshape(-1).detach())
                upper_bounds.append(upper_bounds[-1].reshape(-1).detach())
                self.name_dict[idx + 1] = model_names[model_i]
                model_i += 8  # Flatten is split to 8 ops in BoundedModule
            elif isinstance(layer, ZeroPad2d):
                lower_bounds.append(F.pad(lower_bounds[-1], layer.padding))
                upper_bounds.append(F.pad(upper_bounds[-1], layer.padding))
                self.name_dict[idx + 1] = model_names[model_i]
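                # like Flatten above, ZeroPad2d is traced into multiple ops in the BoundedModule,
                # hence the larger skip below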
                model_i += 24
            else:
                self.name_dict[idx + 1] = model_names[model_i]
                lower_bounds.append([])
                upper_bounds.append([])
                model_i += 1
            idx += 1

        # Also add the bounds of the final output layer
        lower_bounds[-1] = (lb.view(-1).detach())
        upper_bounds[-1] = (ub.view(-1).detach())

        return lower_bounds, upper_bounds, self.pre_relu_indices

    def get_candidate_parallel(self, model, lb, ub, batch):
        # get the intermediate bounds in the current model
        lower_bounds = [
            self.input_domain[:, :, :, 0].squeeze(-1).repeat(batch, 1, 1, 1)
        ]
        upper_bounds = [
            self.input_domain[:, :, :, 1].squeeze(-1).repeat(batch, 1, 1, 1)
        ]
        idx, i, = 0, 0
        for layer in self.layers:
            if isinstance(layer, nn.ReLU):
                i += 1
                this_relu = self.get_relu(model, i)
                lower_bounds[-1] = this_relu.inputs[0].lower.detach()
                upper_bounds[-1] = this_relu.inputs[0].upper.detach()
                lower_bounds.append(F.relu(lower_bounds[-1]).detach(
                ))  # TODO we actually do not need the bounds after ReLU
                upper_bounds.append(F.relu(upper_bounds[-1]).detach())
            elif isinstance(layer, Flatten):
                lower_bounds.append(lower_bounds[-1].reshape(batch,
                                                             -1).detach())
                upper_bounds.append(upper_bounds[-1].reshape(batch,
                                                             -1).detach())
            elif isinstance(layer, nn.ZeroPad2d):
                lower_bounds.append(
                    F.pad(lower_bounds[-1], layer.padding).detach())
                upper_bounds.append(
                    F.pad(upper_bounds[-1], layer.padding).detach())

            else:
                lower_bounds.append([])
                upper_bounds.append([])
            idx += 1

        # Also add the bounds of the final output layer
        lower_bounds[-1] = (lb.view(batch, -1).detach())
        upper_bounds[-1] = (ub.view(batch, -1).detach())

        return lower_bounds, upper_bounds

    def get_mask_parallel(self, model):
        # get the ReLU status mask: 0 means inactive, -1 unstable, and 1 active neurons
        mask = []
        idx, i, = 0, 0
        for layer in self.layers:
            if isinstance(layer, nn.ReLU):
                i += 1
                this_relu = self.get_relu(model, i)
                mask_tmp = torch.zeros_like(this_relu.inputs[0].lower)
                unstable = ((this_relu.inputs[0].lower < 0) &
                            (this_relu.inputs[0].upper > 0))
                mask_tmp[unstable] = -1
                active = (this_relu.inputs[0].lower >= 0)
                mask_tmp[active] = 1
                # otherwise 0, for inactive neurons

                mask.append(mask_tmp.reshape(mask_tmp.size(0), -1))

        ret = []
        for i in range(mask[0].size(0)):
            ret.append([j[i] for j in mask])

        return ret

    def get_beta(self, model):
        b = []
        bm = []
        for m in model._modules.values():
            if isinstance(m, BoundRelu):
                b.append(m.beta.clone().detach())
                bm.append(m.beta_mask.clone().detach())

        retb = []
        retbm = []
        for i in range(b[0].size(0)):
            retb.append([j[i] for j in b])
            retbm.append([j[i] for j in bm])
        return (retb, retbm)

    def get_slope(self, model):
        s = []
        for m in model._modules.values():
            if isinstance(m, BoundRelu):
                s.append(m.slope.transpose(0, 1).clone().detach())

        ret = []
        for i in range(s[0].size(0)):
            ret.append([j[i] for j in s])
        return ret

    def set_slope(self, model, slope):
        idx = 0
        for m in model._modules.values():
            if isinstance(m, BoundRelu):
                # m.slope = slope[idx].repeat(2, *([1] * (slope[idx].ndim - 1))).requires_grad_(True)
                m.slope = slope[idx].repeat(
                    2, *([1] * (slope[idx].ndim - 1))).transpose(
                        0, 1).requires_grad_(True)
                idx += 1

    def reset_beta(self, model, batch=0):
        if batch == 0:
            for m in model._modules.values():
                if isinstance(m, BoundRelu):
                    m.beta.data = m.beta.data * 0.
                    m.beta_mask.data = m.beta_mask.data * 0.
                    # print("beta[{}]".format(batch), m.beta.shape, m.beta_mask.shape)
        else:
            for m in model._modules.values():
                if isinstance(m, BoundRelu):
                    ndim = m.beta.data.ndim
                    # m.beta.data=(m.beta.data[0:1]*0.).repeat(batch*2, *([1] * (ndim - 1))).requires_grad_(True)
                    # m.beta_mask.data=(m.beta_mask.data[0:1]*0.).repeat(batch*2, *([1] * (ndim - 1))).requires_grad_(True)
                    m.beta = torch.zeros(m.beta[:, 0:1].shape).repeat(
                        1, batch * 2, *([1] * (ndim - 2))).detach().to(
                            m.beta.device).requires_grad_(True)
                    m.beta_mask = torch.zeros(m.beta_mask[0:1].shape).repeat(
                        batch * 2, *([1] * (ndim - 2))).detach().to(
                            m.beta.device).requires_grad_(False)
                    # print("beta[{}]".format(batch), m.beta.shape, m.beta_mask.shape)

    def update_bounds_parallel(self,
                               pre_lb_all=None,
                               pre_ub_all=None,
                               decision=None,
                               slopes=None,
                               beta=True,
                               early_stop=True,
                               opt_choice="default",
                               iteration=20,
                               history=[],
                               decision_thresh=0,
                               layer_set_bound=True):
        # update optimized-CROWN bounds in parallel
        total_batch = len(decision)
        decision = np.array(decision)

        layers_need_change = np.unique(decision[:, 0])
        layers_need_change.sort()

        # initial results with empty list
        ret_l = [[] for _ in range(len(decision) * 2)]
        ret_u = [[] for _ in range(len(decision) * 2)]
        masks = [[] for _ in range(len(decision) * 2)]
        ret_s = [[] for _ in range(len(decision) * 2)]

        pre_lb_all_cp = copy.deepcopy(pre_lb_all)
        pre_ub_all_cp = copy.deepcopy(pre_ub_all)

        for idx in layers_need_change:
            # iteratively update upper and lower bounds from earlier to later layers
            tmp_d = np.argwhere(decision[:, 0] == idx)  # .squeeze()
            # idx is the index among ReLU layers; change_idx is the index among all layers
            change_idx = self.pre_relu_indices[idx]

            batch = len(tmp_d)
            select_history = [
                history[idx] for idx in tmp_d.squeeze().reshape(-1)
            ]

            if beta:
                # update beta mask, put it after reset_beta
                # reset beta according to the shape of batch
                self.reset_beta(self.net, batch)

                # print("select history", select_history)

                bound_relus = []
                for m in self.net._modules.values():
                    if isinstance(m, BoundRelu):
                        bound_relus.append(m)
                        m.beta_mask.data = m.beta_mask.data.view(batch * 2, -1)

                for bi in range(batch):
                    d = tmp_d[bi][0]
                    # assign current decision to each point of a batch
                    bound_relus[int(decision[d][0])].beta_mask.data[
                        bi, int(decision[d][1])] = 1
                    bound_relus[int(decision[d][0])].beta_mask.data[
                        bi + batch, int(decision[d][1])] = -1
                    # print("assign", bi, decision[d], 1, bound_relus[decision[d][0]].beta_mask.data[bi, decision[d][1]])
                    # print("assign", bi+batch, decision[d], -1, bound_relus[decision[d][0]].beta_mask.data[bi+batch, decision[d][1]])
                    # assign history decision according to select_history
                    for (hid, hl), hc in select_history[bi]:
                        bound_relus[hid].beta_mask.data[bi, hl] = int(
                            (hc - 0.5) * 2)
                        bound_relus[hid].beta_mask.data[bi + batch, hl] = int(
                            (hc - 0.5) * 2)
                        # print("assign", bi, [hid, hl], hc, bound_relus[hid].beta_mask.data[bi, hl])
                        # print("assign", bi+batch, [hid, hl], hc, bound_relus[hid].beta_mask.data[bi+batch, hl])

                # sanity check: beta_mask should only be assigned for split nodes
                for m in bound_relus:
                    m.beta_mask.data = m.beta_mask.data.view(m.beta[0].shape)

            slope_select = [i[tmp_d.squeeze()].clone() for i in slopes]

            pre_lb_all = [i[tmp_d.squeeze()].clone() for i in pre_lb_all_cp]
            pre_ub_all = [i[tmp_d.squeeze()].clone() for i in pre_ub_all_cp]

            if batch == 1:
                pre_lb_all = [i.clone().unsqueeze(0) for i in pre_lb_all]
                pre_ub_all = [i.clone().unsqueeze(0) for i in pre_ub_all]
                slope_select = [i.clone().unsqueeze(0) for i in slope_select]

            upper_bounds = [i.clone() for i in pre_ub_all[:change_idx + 1]]
            lower_bounds = [i.clone() for i in pre_lb_all[:change_idx + 1]]
            upper_bounds_cp = copy.deepcopy(upper_bounds)
            lower_bounds_cp = copy.deepcopy(lower_bounds)

            for i in range(batch):
                d = tmp_d[i][0]
                upper_bounds[change_idx].view(batch, -1)[i][decision[d][1]] = 0
                lower_bounds[change_idx].view(batch, -1)[i][decision[d][1]] = 0

            pre_lb_all = [torch.cat(2 * [i]) for i in pre_lb_all]
            pre_ub_all = [torch.cat(2 * [i]) for i in pre_ub_all]

            # merge the inactive and active splits together
            new_candidate = {}
            for i, (l, uc, lc, u) in enumerate(
                    zip(lower_bounds, upper_bounds_cp, lower_bounds_cp,
                        upper_bounds)):
                # we set lower = 0 in first half batch, and upper = 0 in second half batch
                new_candidate[self.name_dict[i]] = [
                    torch.cat((l, lc), dim=0),
                    torch.cat((uc, u), dim=0)
                ]

            if not layer_set_bound:
                new_candidate_p = {}
                for i, (l,
                        u) in enumerate(zip(pre_lb_all[:-2], pre_ub_all[:-2])):
                    # we set lower = 0 in first half batch, and upper = 0 in second half batch
                    new_candidate_p[self.name_dict[i]] = [l, u]

            # create new_x here since batch may change
            ptb = PerturbationLpNorm(
                norm=self.x.ptb.norm,
                eps=self.x.ptb.eps,
                x_L=self.x.ptb.x_L.repeat(batch * 2, 1, 1, 1),
                x_U=self.x.ptb.x_U.repeat(batch * 2, 1, 1, 1))
            new_x = BoundedTensor(self.x.data.repeat(batch * 2, 1, 1, 1), ptb)
            self.net(
                new_x
            )  # batch may change, so we need to do forward to set some shapes here

            if len(slope_select) > 0:
                # set slope here again
                self.set_slope(self.net, slope_select)

            torch.cuda.empty_cache()
            if layer_set_bound:
                # we fix the intermediate bounds before the change_idx-th layer using normal CROWN
                if self.solve_slope and change_idx >= self.pre_relu_indices[-1]:
                    # we split a ReLU in the last layer, so directly use optimized CROWN
                    self.net.set_bound_opts({
                        'ob_start_idx':
                        sum(change_idx <= x for x in self.pre_relu_indices),
                        'ob_beta':
                        beta,
                        'ob_update_by_layer':
                        layer_set_bound,
                        'ob_iteration':
                        iteration
                    })
                    lb, ub, = self.net.compute_bounds(
                        x=(new_x, ),
                        IBP=False,
                        C=None,
                        method='CROWN-Optimized',
                        new_interval=new_candidate,
                        return_A=False,
                        bound_upper=False)
                else:
                    # we split a ReLU before the last layer; calculate intermediate bounds using normal CROWN
                    self.net.set_relu_used_count(
                        sum(change_idx <= x for x in self.pre_relu_indices))
                    with torch.no_grad():
                        lb, ub, = self.net.compute_bounds(
                            x=(new_x, ),
                            IBP=False,
                            C=None,
                            method='backward',
                            new_interval=new_candidate,
                            bound_upper=False,
                            return_A=False)

                # we don't care about the upper bound of the last layer
                lower_bounds_new, upper_bounds_new = self.get_candidate_parallel(
                    self.net, lb, lb + 99, batch * 2)

                if change_idx < self.pre_relu_indices[-1]:
                    # check whether we had better bounds before, and preset all intermediate bounds
                    for i, (l, u) in enumerate(
                            zip(lower_bounds_new[change_idx + 2:-1],
                                upper_bounds_new[change_idx + 2:-1])):
                        new_candidate[self.name_dict[i + change_idx + 2]] = [
                            torch.max(l, pre_lb_all[i + change_idx + 2]),
                            torch.min(u, pre_ub_all[i + change_idx + 2])
                        ]

                    if self.solve_slope:
                        self.net.set_bound_opts({
                            'ob_start_idx':
                            sum(change_idx <= x
                                for x in self.pre_relu_indices),
                            'ob_beta':
                            beta,
                            'ob_update_by_layer':
                            layer_set_bound,
                            'ob_iteration':
                            iteration
                        })
                        lb, ub, = self.net.compute_bounds(
                            x=(new_x, ),
                            IBP=False,
                            C=None,
                            method='CROWN-Optimized',
                            new_interval=new_candidate,
                            return_A=False,
                            bound_upper=False)
                    else:
                        self.net.set_relu_used_count(
                            sum(change_idx <= x
                                for x in self.pre_relu_indices))
                        with torch.no_grad():
                            lb, ub, = self.net.compute_bounds(
                                x=(new_x, ),
                                IBP=False,
                                C=None,
                                method='backward',
                                new_interval=new_candidate,
                                bound_upper=False,
                                return_A=False)

            else:
                # all intermediate bounds are re-calculated by optimized CROWN
                self.net.set_bound_opts({
                    'ob_start_idx': 99,
                    'ob_beta': beta,
                    'ob_update_by_layer': layer_set_bound,
                    'ob_iteration': iteration
                })
                lb, ub, = self.net.compute_bounds(x=(new_x, ),
                                                  IBP=False,
                                                  C=None,
                                                  method='CROWN-Optimized',
                                                  new_interval=new_candidate_p,
                                                  return_A=False,
                                                  bound_upper=False)

            # print('best results of parent nodes', pre_lb_all[-1].repeat(2, 1))
            # print('finally, after optimization:', lower_bounds_new[-1])

            # primal = self.get_primals(A_dict, return_x=True)
            lower_bounds_new, upper_bounds_new = self.get_candidate_parallel(
                self.net, lb, lb + 99, batch * 2)

            lower_bounds_new[-1] = torch.max(lower_bounds_new[-1],
                                             pre_lb_all[-1])
            upper_bounds_new[-1] = torch.min(upper_bounds_new[-1],
                                             pre_ub_all[-1])

            mask = self.get_mask_parallel(self.net)
            if len(slope_select) > 0:
                slope = self.get_slope(self.net)

            # reshape the results
            for i in range(len(tmp_d)):
                ret_l[int(tmp_d[i])] = [j[i] for j in lower_bounds_new]
                ret_l[int(tmp_d[i] + total_batch)] = [
                    j[i + batch] for j in lower_bounds_new
                ]

                ret_u[int(tmp_d[i])] = [j[i] for j in upper_bounds_new]
                ret_u[int(tmp_d[i] + total_batch)] = [
                    j[i + batch] for j in upper_bounds_new
                ]

                masks[int(tmp_d[i])] = mask[i]
                masks[int(tmp_d[i] + total_batch)] = mask[i + batch]
                if len(slope_select) > 0:
                    ret_s[int(tmp_d[i])] = slope[i]
                    ret_s[int(tmp_d[i] + total_batch)] = slope[i + batch]

        return ret_l, ret_u, masks, ret_s

    def fake_forward(self, x):
        for layer in self.layers:
            if type(layer) is nn.Linear:
                x = F.linear(x, layer.weight, layer.bias)
            elif type(layer) is nn.Conv2d:
                x = F.conv2d(x, layer.weight, layer.bias, layer.stride,
                             layer.padding, layer.dilation, layer.groups)
            elif type(layer) == nn.ReLU:
                x = F.relu(x)
            elif type(layer) == Flatten:
                x = x.reshape(x.shape[0], -1)
            elif type(layer) == nn.ZeroPad2d:
                x = F.pad(x, layer.padding)
            else:
                print(type(layer))
                raise NotImplementedError

        return x

    def get_primals(self, A, return_x=False):
        # get primal input by using A matrix
        input_A_lower = A[self.layer_names[-1]][self.net.input_name[0]][0]
        batch = input_A_lower.shape[1]
        l = self.input_domain[:, :, :, 0].repeat(batch, 1, 1, 1)
        u = self.input_domain[:, :, :, 1].repeat(batch, 1, 1, 1)
        diff = 0.5 * (l - u)  # already flip the sign by using lower - upper
        net_input = diff * torch.sign(input_A_lower.squeeze(0)) + self.x
        if return_x: return net_input

        primals = [net_input]
        for layer in self.layers:
            if type(layer) is nn.Linear:
                pre = primals[-1]
                primals.append(F.linear(pre, layer.weight, layer.bias))
            elif type(layer) is nn.Conv2d:
                pre = primals[-1]
                primals.append(
                    F.conv2d(pre, layer.weight, layer.bias, layer.stride,
                             layer.padding, layer.dilation, layer.groups))
            elif type(layer) == nn.ReLU:
                primals.append(F.relu(primals[-1]))
            elif type(layer) == Flatten:
                primals.append(primals[-1].reshape(primals[-1].shape[0], -1))
            else:
                print(type(layer))
                raise NotImplementedError

        # primals = primals[1:]
        primals = [i.detach().clone() for i in primals]
        # print('primals', primals[-1])

        return net_input, primals

    def get_relu_mask(self):
        relu_mask = []
        relu_idx = 0
        for layer in self.layers:
            if type(layer) == nn.ReLU:
                relu_idx += 1
                this_relu = self.get_relu(self.net, relu_idx)
                new_layer_mask = []
                ratios_all = this_relu.d.squeeze(0)
                for slope in ratios_all.flatten():
                    if slope.item() == 1.0:
                        new_layer_mask.append(1)
                    elif slope.item() == 0.0:
                        new_layer_mask.append(0)
                    else:
                        new_layer_mask.append(-1)
                relu_mask.append(
                    torch.tensor(new_layer_mask).to(self.x.device))

        return relu_mask

    def build_the_model(self, input_domain, x, no_lp=True, decision_thresh=0):
        self.x = x
        self.input_domain = input_domain

        slope_opt = None

        # first get CROWN bounds
        if self.solve_slope:
            self.net.init_slope(self.x)
            self.net.set_bound_opts({
                'ob_iteration': 100,
                'ob_beta': False,
                'ob_alpha': True,
                'ob_opt_choice': "adam",
                'ob_decision_thresh': decision_thresh,
                'ob_early_stop': False,
                'ob_log': False,
                'ob_start_idx': 99,
                'ob_keep_best': True,
                'ob_update_by_layer': True,
                'ob_lr': 0.1
            })
            lb, ub, A_dict = self.net.compute_bounds(x=(x, ),
                                                     IBP=False,
                                                     C=None,
                                                     method='CROWN-Optimized',
                                                     return_A=True,
                                                     bound_upper=False)
            slope_opt = self.get_slope(
                self.net)[0]  # initial with one node only
        else:
            with torch.no_grad():
                lb, ub, A_dict = self.net.compute_bounds(x=(x, ),
                                                         IBP=False,
                                                         C=None,
                                                         method='backward',
                                                         return_A=True)

        # build a complete A_dict
        self.layer_names = list(A_dict[list(A_dict.keys())[-1]].keys())[2:]
        self.layer_names.sort()

        # update bounds
        print('initial CROWN bounds:', lb, ub)
        primals, mini_inp = None, None
        # mini_inp, primals = self.get_primals(self.A_dict)
        lb, ub, pre_relu_indices = self.get_candidate(
            self.net, lb, lb + 99)  # primals are better upper bounds
        duals = None

        return ub[-1], lb[-1], mini_inp, duals, primals, self.get_relu_mask(
        ), lb, ub, pre_relu_indices, slope_opt
Example #11
class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.features = nn.Sequential(nn.Linear(5, 10), nn.ReLU(),
                                      nn.Linear(10, 3))

    def forward(self, input):
        return self.features(input)


# Assumed placeholder input: a random batch of five 5-dimensional vectors.
input_vec = torch.randn(5, 5)

raw_model = mynet()
bound_model = BoundedModule(raw_model, input_vec)
num_actions = 3
batchsize = 5
label = torch.tensor([0, 2, 1, 1, 0])
bnd_state = BoundedTensor(input_vec, PerturbationLpNorm(norm=np.inf, eps=0.1))

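# Specification matrix: for each sample, the two rows of c encode the margins
# f_label(x) - f_j(x) for the two classes j != label, so compute_bounds(C=c, ...)
# returns bounds on these margins directly.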
c = torch.eye(3).type_as(input_vec)[label].unsqueeze(1) - torch.eye(3).type_as(
    input_vec).unsqueeze(0)
I = (~(label.data.unsqueeze(1) == torch.arange(3).type_as(
    label.data).unsqueeze(0)))
c = (c[I].view(input_vec.size(0), 2, 3))

pred = bound_model(input_vec)
basic_bound, _ = bound_model.compute_bounds(IBP=False, method='backward')
advance_bound, _ = bound_model.compute_bounds(C=c,
                                              IBP=False,
                                              method='backward')
print(basic_bound.detach().numpy())
print(advance_bound.detach().numpy())
Example #12
                      bound_opts={"conv_mode": conv_mode},
                      device="cuda")

## Step 4: Compute bounds using LiRPA given a perturbation
eps = 0.03
norm = np.inf
ptb = PerturbationLpNorm(norm=norm, eps=eps)
image = BoundedTensor(image, ptb)
# Get model prediction as usual
pred = model(image)

# Compute bounds
torch.cuda.empty_cache()
print('Using {} mode to compute convolution.'.format(conv_mode))
lb, ub = model.compute_bounds(x=(image, ),
                              IBP=False,
                              C=None,
                              method='backward')

## Step 5: Final output
# pred = pred.detach().cpu().numpy()
lb = lb.detach().cpu().numpy()
ub = ub.detach().cpu().numpy()
for i in range(N):
    # print("Image {} top-1 prediction {}".format(i, label[i]))
    for j in range(n_classes):
        print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format(
            j=j, l=lb[i][j], u=ub[i][j]))
    print()

# Print the GPU memory usage
print('Memory usage in "{}" mode:'.format(conv_mode))
Example #13
image = test_data.data[:N].view(N, 1, 28, 28).cuda()
# Convert to float
image = image.to(torch.float32) / 255.0

## Step 3: wrap model with auto_LiRPA
# The second parameter is for constructing the trace of the computational graph, and its content is not important.
model = BoundedModule(model, torch.empty_like(image), device="cuda")

## Step 4: Compute bounds using LiRPA given a perturbation
eps = 0.3
norm = np.inf
# ptb = PerturbationL0Norm(eps=eps)
ptb = PerturbationLpNorm(norm=norm, eps=eps)
image = BoundedTensor(image, ptb)
# Get model prediction as usual
pred = model(image)
# label = torch.argmax(pred, dim=1).cpu().numpy()
# Compute bounds
lb, ub = model.compute_bounds(IBP=False)

## Step 5: Final output
# pred = pred.detach().cpu().numpy()
lb = lb.detach().cpu().numpy()
ub = ub.detach().cpu().numpy()
for i in range(N):
    # print("Image {} top-1 prediction {}".format(i, label[i]))
    for j in range(n_classes):
        print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}".format(
            j=j, l=lb[i][j], u=ub[i][j]))
    print()
Example #14
class Trainer():
    '''
    This is a class representing a Policy Gradient trainer, which 
    trains both a deep Policy network and a deep Value network.
    Exposes functions:
    - advantage_and_return
    - multi_actor_step
    - reset_envs
    - run_trajectories
    - train_step
    Trainer also handles all logging, which is done via the "cox"
    library
    '''
    def __init__(self, policy_net_class, value_net_class, params,
                 store, advanced_logging=True, log_every=5):
        '''
        Initializes a new Trainer class.
        Inputs:
        - policy_net_class, the class of policy network to use (inheriting from nn.Module)
        - value_net_class, the class of value network to use (inheriting from nn.Module)
        - params, a dictionary with all of the required hyperparameters
        - store, a cox store used for logging (or None to disable saving)
        '''
        # Parameter Loading
        self.params = Parameters(params)

        # Whether or not the value network uses the current timestep
        time_in_state = self.VALUE_CALC == "time"

        # Whether to use GPU (as opposed to CPU)
        if not self.CPU:
            torch.set_default_tensor_type("torch.cuda.FloatTensor")

        # Environment Loading
        def env_constructor():
            # Whether or not we should add the time to the state
            horizon_to_feed = self.T if time_in_state else None
            return Env(self.GAME, norm_states=self.NORM_STATES,
                       norm_rewards=self.NORM_REWARDS,
                       params=self.params,
                       add_t_with_horizon=horizon_to_feed,
                       clip_obs=self.CLIP_OBSERVATIONS,
                       clip_rew=self.CLIP_REWARDS,
                       show_env=self.SHOW_ENV,
                       save_frames=self.SAVE_FRAMES,
                       save_frames_path=self.SAVE_FRAMES_PATH)

        self.envs = [env_constructor() for _ in range(self.NUM_ACTORS)]
        self.params.AGENT_TYPE = "discrete" if self.envs[0].is_discrete else "continuous"
        self.params.NUM_ACTIONS = self.envs[0].num_actions
        self.params.NUM_FEATURES = self.envs[0].num_features
        self.policy_step = step_with_mode(self.MODE)
        self.params.MAX_KL_INCREMENT = (self.params.MAX_KL_FINAL - self.params.MAX_KL) / self.params.TRAIN_STEPS
        self.advanced_logging = advanced_logging
        self.n_steps = 0
        self.log_every = log_every

        # Instantiation
        self.policy_model = policy_net_class(self.NUM_FEATURES, self.NUM_ACTIONS,
                                             self.INITIALIZATION,
                                             time_in_state=time_in_state,
                                             activation=self.policy_activation)

        # Instantiate convex relaxation model when mode is 'robust_ppo'
        if self.MODE == 'robust_ppo':
            self.create_relaxed_model(time_in_state)

        opts_ok = (self.PPO_LR == -1 or self.PPO_LR_ADAM == -1)
        assert opts_ok, "One of ppo_lr and ppo_lr_adam must be -1 (off)."
        # Whether we should use Adam or simple GD to optimize the policy parameters
        if self.PPO_LR_ADAM != -1:
            kwargs = {
                'lr':self.PPO_LR_ADAM,
            }

            if self.params.ADAM_EPS > 0:
                kwargs['eps'] = self.ADAM_EPS

            self.params.POLICY_ADAM = optim.Adam(self.policy_model.parameters(),
                                                 **kwargs)
        else:
            self.params.POLICY_ADAM = optim.SGD(self.policy_model.parameters(), lr=self.PPO_LR)

        # If using a time dependent value function, add one extra feature
        # for the time ratio t/T
        if time_in_state:
            self.params.NUM_FEATURES = self.NUM_FEATURES + 1

        # Value function optimization
        self.val_model = value_net_class(self.NUM_FEATURES, self.INITIALIZATION)
        self.val_opt = optim.Adam(self.val_model.parameters(), lr=self.VAL_LR, eps=1e-5) 
        assert self.policy_model.discrete == (self.AGENT_TYPE == "discrete")

        # Learning rate annealing
        # From OpenAI hyperparameters:
        # Set adam learning rate to 3e-4 * alpha, where alpha decays from 1 to 0 over training
        if self.ANNEAL_LR:
            lam = lambda f: 1-f/self.TRAIN_STEPS
            ps = optim.lr_scheduler.LambdaLR(self.POLICY_ADAM, 
                                                    lr_lambda=lam)
            vs = optim.lr_scheduler.LambdaLR(self.val_opt, lr_lambda=lam)
            self.params.POLICY_SCHEDULER = ps
            self.params.VALUE_SCHEDULER = vs

        if store is not None:
            self.setup_stores(store)
        else:
            print("Not saving results to cox store.")

    

    def create_relaxed_model(self, time_in_state=False):
        # Create state perturbation model for robust PPO training.
        if isinstance(self.policy_model, CtsPolicy):
            from .convex_relaxation import RelaxedCtsPolicyForState
            relaxed_policy_model = RelaxedCtsPolicyForState(
                    self.NUM_FEATURES, self.NUM_ACTIONS, time_in_state=time_in_state,
                    activation=self.policy_activation, policy_model=self.policy_model)
            dummy_input1 = torch.randn(1, self.NUM_FEATURES)
            inputs = (dummy_input1, )
            self.relaxed_policy_model = BoundedModule(relaxed_policy_model, inputs)
            self.robust_eps_scheduler = LinearScheduler(self.params.ROBUST_PPO_EPS, self.params.ROBUST_PPO_EPS_SCHEDULER_OPTS)
            if self.params.ROBUST_PPO_BETA_SCHEDULER_OPTS == "same":
                self.robust_beta_scheduler = LinearScheduler(self.params.ROBUST_PPO_BETA, self.params.ROBUST_PPO_EPS_SCHEDULER_OPTS)
            else:
                self.robust_beta_scheduler = LinearScheduler(self.params.ROBUST_PPO_BETA, self.params.ROBUST_PPO_BETA_SCHEDULER_OPTS)
        else:
            raise NotImplementedError

    """Initialize sarsa training."""
    def setup_sarsa(self, lr_schedule, eps_scheduler, beta_scheduler):
        # Create the Sarsa model, with S and A as the input.
        self.sarsa_model = ValueDenseNet(self.NUM_FEATURES + self.NUM_ACTIONS, self.INITIALIZATION)
        self.sarsa_opt = optim.Adam(self.sarsa_model.parameters(), lr=self.VAL_LR, eps=1e-5)
        self.sarsa_scheduler = optim.lr_scheduler.LambdaLR(self.sarsa_opt, lr_schedule)
        self.sarsa_eps_scheduler = eps_scheduler
        self.sarsa_beta_scheduler = beta_scheduler
        # Convert model with relaxation wrapper.
        dummy_input = torch.randn(1, self.NUM_FEATURES + self.NUM_ACTIONS)
        self.relaxed_sarsa_model = BoundedModule(self.sarsa_model, dummy_input)

    def setup_stores(self, store):
        # Logging setup
        self.store = store
        self.store.add_table('optimization', {
            'mean_reward':float,
            'final_value_loss':float,
            'mean_std':float
        })

        if self.advanced_logging:
            paper_constraint_cols = {
                'avg_kl':float,
                'max_kl':float,
                'max_ratio':float,
                'opt_step':int
            }

            value_cols = {
                'heldout_gae_loss':float,
                'heldout_returns_loss':float,
                'train_gae_loss':float,
                'train_returns_loss':float
            }

            weight_cols = {}
            for name, _ in self.policy_model.named_parameters():
                name += "."
                for k in ["l1", "l2", "linf", "delta_l1", "delta_l2", "delta_linf"]:
                    weight_cols[name + k] = float

            self.store.add_table('paper_constraints_train',
                                        paper_constraint_cols)
            self.store.add_table('paper_constraints_heldout',
                                        paper_constraint_cols)
            self.store.add_table('value_data', value_cols)
            self.store.add_table('weight_updates', weight_cols)

        if self.params.MODE == 'robust_ppo':
            robust_cols ={
                'eps': float,
                'beta': float,
                'kl': float,
                'surrogate': float,
                'entropy': float,
                'loss': float,
            }
            self.store.add_table('robust_ppo_data', robust_cols)


    def __getattr__(self, x):
        '''
        Allows accessing self.A instead of self.params.A
        '''
        if x == 'params':
            return {}
        try:
            return getattr(self.params, x)
        except KeyError:
            raise AttributeError(x)

    def advantage_and_return(self, rewards, values, not_dones):
        """
        Calculate GAE advantage, discounted returns, and 
        true reward (average reward per trajectory)

        GAE: delta_t^V = r_t + discount * V(s_{t+1}) - V(s_t)
        using formula from John Schulman's code:
        V(s_{t+1}) = {0             if s_t is terminal
                     {v_{s_{t+1}}   if s_t not terminal and t != T
                     {v_{s_t}       if s_t not terminal and t == T (last step)
        """
        assert shape_equal_cmp(rewards, values, not_dones)
        
        V_s_tp1 = ch.cat([values[:,1:], values[:, -1:]], 1) * not_dones
        deltas = rewards + self.GAMMA * V_s_tp1 - values

        # now we need to discount each path by gamma * lam
        advantages = ch.zeros_like(rewards)
        returns = ch.zeros_like(rewards)
        indices = get_path_indices(not_dones)
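        # discount_path computes a discounted cumulative sum within one completed path:
        # with factor gamma*lambda over the TD residuals `deltas` this gives the GAE
        # advantage, and with factor gamma over the raw rewards it gives the discounted return.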
        for agent, start, end in indices:
            advantages[agent, start:end] = discount_path( \
                    deltas[agent, start:end], self.LAMBDA*self.GAMMA)
            returns[agent, start:end] = discount_path( \
                    rewards[agent, start:end], self.GAMMA)

        return advantages.clone().detach(), returns.clone().detach()

    def reset_envs(self, envs):
        '''
        Resets environments and returns initial state with shape:
        (# actors, 1, ... state_shape)
	    '''
        if self.CPU:
            return cpu_tensorize([env.reset() for env in envs]).unsqueeze(1)
        else:
            return cu_tensorize([env.reset() for env in envs]).unsqueeze(1)

    def multi_actor_step(self, actions, envs):
        '''
        Simulate a "step" by several actors on their respective environments
        Inputs:
        - actions, list of actions to take
        - envs, list of the environments in which to take the actions
        Returns:
        - completed_episode_info, a variable-length list of final rewards and episode lengths
            for the actors which have completed
        - rewards, an actors-length tensor with the rewards collected
        - states, a (actors, ... state_shape) tensor with resulting states
        - not_dones, an actors-length tensor with 0 if terminal, 1 otherwise
        '''
        normed_rewards, states, not_dones = [], [], []
        completed_episode_info = []
        for action, env in zip(actions, envs):
            gym_action = action[0].cpu().numpy()
            new_state, normed_reward, is_done, info = env.step(gym_action)
            if is_done:
                completed_episode_info.append(info['done'])
                new_state = env.reset()

            # Aggregate
            normed_rewards.append([normed_reward])
            not_dones.append([int(not is_done)])
            states.append([new_state])

        tensor_maker = cpu_tensorize if self.CPU else cu_tensorize
        data = list(map(tensor_maker, [normed_rewards, states, not_dones]))
        return [completed_episode_info, *data]

    def run_trajectories(self, num_saps, return_rewards=False, should_tqdm=False):
        """
        Resets environments, and runs self.T steps in each environment in 
        self.envs. If an environment hits a terminal state, the env is
        restarted and the terminal timestep marked. Each item in the tuple is
        a tensor in which the first coordinate represents the actor, and the
        second coordinate represents the time step. The third+ coordinates, if
        they exist, represent additional information for each time step.
        Inputs: num_saps, the total number of state-action pairs to collect across actors
        Returns:
        - rewards: (# actors, self.T)
        - not_dones: (# actors, self.T) 1 at a timestep if the state is non-terminal, else 0
        - actions: (# actors, self.T, ) indices of actions
        - action_logprobs: (# actors, self.T, ) log probabilities of each action
        - states: (# actors, self.T, ... state_shape) states
        """
        # Arrays to be updated with historic info
        envs = self.envs
        initial_states = self.reset_envs(envs)

        # Holds information (length and true reward) about completed episodes
        completed_episode_info = []
        traj_length = int(num_saps // self.NUM_ACTORS)

        shape = (self.NUM_ACTORS, traj_length)
        all_zeros = [ch.zeros(shape) for i in range(3)]
        rewards, not_dones, action_log_probs = all_zeros

        actions_shape = shape + (self.NUM_ACTIONS,)
        actions = ch.zeros(actions_shape)
        # Mean of the action distribution. Used to avoid unnecessary recomputation.
        action_means = ch.zeros(actions_shape)
        # Log Std of the action distribution.
        action_stds = ch.zeros(actions_shape)

        states_shape = (self.NUM_ACTORS, traj_length+1) + initial_states.shape[2:]
        states =  ch.zeros(states_shape)

        iterator = range(traj_length) if not should_tqdm else tqdm.trange(traj_length)

        assert self.NUM_ACTORS == 1

        states[:, 0, :] = initial_states
        last_states = states[:, 0, :]
        for t in iterator:
            # assert shape_equal([self.NUM_ACTORS, self.NUM_FEATURES], last_states)
            # Retrieve probabilities 
            # action_pds: (# actors, # actions), prob dists over actions
            # next_actions: (# actors, 1), indices of actions
            # next_action_probs: (# actors, 1), prob of taken actions
            last_states = self.apply_attack(last_states)
            # Note that for adversarial training, we use the state under perturbation to get the actions.
            # However in the trajectory we still save the state without perturbation as the true environment states are not perturbed.
            action_pds = self.policy_model(last_states)
            next_action_means, next_action_stds = action_pds
            next_actions = self.policy_model.sample(action_pds)
            next_action_log_probs = self.policy_model.get_loglikelihood(action_pds, next_actions)

            next_action_log_probs = next_action_log_probs.unsqueeze(1)
            # shape_equal([self.NUM_ACTORS, 1], next_action_log_probs)

            # If discrete, next_actions is (# actors, 1);
            # otherwise, if continuous, (# actors, 1, action dim).
            next_actions = next_actions.unsqueeze(1)
            # if self.policy_model.discrete:
            #     assert shape_equal([self.NUM_ACTORS, 1], next_actions)
            # else:
            #     assert shape_equal([self.NUM_ACTORS, 1, self.policy_model.action_dim])

            ret = self.multi_actor_step(next_actions, envs)

            # done_info = List of (length, reward) pairs for each completed trajectory
            # (next_rewards, next_states, next_not_dones) act like a multi-actor env.step()
            done_info, next_rewards, next_states, next_not_dones = ret
            # assert shape_equal([self.NUM_ACTORS, 1], next_rewards, next_not_dones)
            # assert shape_equal([self.NUM_ACTORS, 1, self.NUM_FEATURES], next_states)

            # If some of the actors finished AND this is not the last step
            # OR some of the actors finished AND we have no episode information
            if len(done_info) > 0 and (t != traj_length - 1 or len(completed_episode_info) == 0):
                completed_episode_info.extend(done_info)

            # Update histories
            # each shape: (nact, t, ...) -> (nact, t + 1, ...)

            pairs = [
                (rewards, next_rewards),
                (not_dones, next_not_dones),
                (actions, next_actions), # The sampled actions.
                (action_means, next_action_means), # The Gaussian mean of actions.
                # (action_stds, next_action_stds), # The Gaussian std of actions, is a constant, no need to save.
                (action_log_probs, next_action_log_probs),
                (states, next_states),
            ]

            last_states = next_states[:, 0, :]
            for total, v in pairs:
                if total is states:
                    # Next states, stores in the next position.
                    total[:, t+1] = v
                else:
                    # The current action taken, and reward received.
                    total[:, t] = v

        # Calculate the average episode length and true rewards over all the trajectories
        infos = np.array(list(zip(*completed_episode_info)))
        # print(infos)
        if infos.size > 0:
            _, ep_rewards = infos
            avg_episode_length, avg_episode_reward = np.mean(infos, axis=1)
        else:
            ep_rewards = [-1]
            avg_episode_length = -1
            avg_episode_reward = -1

        # Last state is never acted on, discard
        states = states[:,:-1,:]
        trajs = Trajectories(rewards=rewards, 
            action_log_probs=action_log_probs, not_dones=not_dones, 
            actions=actions, states=states, action_means=action_means, action_std=next_action_stds)

        to_ret = (avg_episode_length, avg_episode_reward, trajs)
        if return_rewards:
            to_ret += (ep_rewards,)

        return to_ret

    """Conduct adversarial attack using value network."""
    def apply_attack(self, last_states):
        if self.params.ATTACK_RATIO < random.random():
            # Only attack a portion of steps.
            return last_states
        eps = self.params.ATTACK_EPS
        if eps == "same":
            eps = self.params.ROBUST_PPO_EPS
        else:
            eps = float(eps)
        steps = self.params.ATTACK_STEPS
        if self.params.ATTACK_METHOD == "critic":
            # Find a state close to last_states that decreases the value the most.
            if steps > 0:
                if self.params.ATTACK_STEP_EPS == "auto":
                    step_eps = eps / steps
                else:
                    step_eps = float(self.params.ATTACK_STEP_EPS)
                clamp_min = last_states - eps
                clamp_max = last_states + eps
                # Random start.
                noise = torch.empty_like(last_states).uniform_(-step_eps, step_eps)
                states = last_states + noise
                with torch.enable_grad():
                    for i in range(steps):
                        states = states.clone().detach().requires_grad_()
                        value = self.val_model(states).mean(dim=1)
                        value.backward()
                        update = states.grad.sign() * step_eps
                        # Clamp to +/- eps.
                        states.data = torch.min(torch.max(states.data - update, clamp_min), clamp_max)
                    self.val_model.zero_grad()
                return states.detach()
            else:
                return last_states
        elif self.params.ATTACK_METHOD == "random":
            # Apply uniform random noise.
            noise = torch.empty_like(last_states).uniform_(-eps, eps)
            return (last_states + noise).detach()
        elif self.params.ATTACK_METHOD == "action":
            if steps > 0:
                if self.params.ATTACK_STEP_EPS == "auto":
                    step_eps = eps / steps
                else:
                    step_eps = float(self.params.ATTACK_STEP_EPS)
                clamp_min = last_states - eps
                clamp_max = last_states + eps
                # SGLD noise factor. We simply set beta=1.
                noise_factor = np.sqrt(2 * step_eps)
                noise = torch.randn_like(last_states) * noise_factor
                # The first step has gradient zero, so add the noise and projection directly.
                states = last_states + noise.sign() * step_eps
                # Current action at this state.
                old_action, old_stdev = self.policy_model(last_states)
                # Normalize stdev to avoid numerical issues.
                old_stdev /= (old_stdev.mean())
                old_action = old_action.detach()
                with torch.enable_grad():
                    for i in range(steps):
                        states = states.clone().detach().requires_grad_()
                        action_change = (self.policy_model(states)[0] - old_action) / old_stdev
                        action_change = (action_change * action_change).sum(dim=1)
                        action_change.backward()
                        # Reduce noise at every step.
                        noise_factor = np.sqrt(2 * step_eps) / (i+2)
                        # Project noisy gradient to step boundary.
                        update = (states.grad + noise_factor * torch.randn_like(last_states)).sign() * step_eps
                        # Clamp to +/- eps.
                        states.data = torch.min(torch.max(states.data + update, clamp_min), clamp_max)
                    self.policy_model.zero_grad()
                return states.detach()
            else:
                return last_states
        elif self.params.ATTACK_METHOD == "sarsa" or self.params.ATTACK_METHOD == "sarsa+action":
            # Attack using a learned value network.
            assert self.params.ATTACK_SARSA_NETWORK is not None
            use_action = self.params.ATTACK_SARSA_ACTION_RATIO > 0 and self.params.ATTACK_METHOD == "sarsa+action"
            action_ratio = self.params.ATTACK_SARSA_ACTION_RATIO
            assert action_ratio >= 0 and action_ratio <= 1
            if not hasattr(self, "sarsa_network"):
                self.sarsa_network = ValueDenseNet(state_dim=self.NUM_FEATURES+self.NUM_ACTIONS, init="normal")
                print("Loading sarsa network", self.params.ATTACK_SARSA_NETWORK)
                sarsa_ckpt = torch.load(self.params.ATTACK_SARSA_NETWORK)
                sarsa_meta = sarsa_ckpt['metadata']
                sarsa_eps = sarsa_meta['sarsa_eps'] if 'sarsa_eps' in sarsa_meta else "unknown"
                sarsa_reg = sarsa_meta['sarsa_reg'] if 'sarsa_reg' in sarsa_meta else "unknown"
                sarsa_steps = sarsa_meta['sarsa_steps'] if 'sarsa_steps' in sarsa_meta else "unknown"
                print(f"Sarsa network was trained with eps={sarsa_eps}, reg={sarsa_reg}, steps={sarsa_steps}")
                if use_action:
                    print(f"objective: {1.0 - action_ratio} * sarsa + {action_ratio} * action_change")
                else:
                    print("Not adding action change objective.")
                self.sarsa_network.load_state_dict(sarsa_ckpt['state_dict'])
            if steps > 0:
                if self.params.ATTACK_STEP_EPS == "auto":
                    step_eps = eps / steps
                else:
                    step_eps = float(self.params.ATTACK_STEP_EPS)
                clamp_min = last_states - eps
                clamp_max = last_states + eps
                # Random start.
                noise = torch.empty_like(last_states).uniform_(-step_eps, step_eps)
                states = last_states + noise
                if use_action:
                    # Current action at this state.
                    old_action, old_stdev = self.policy_model(last_states)
                    old_stdev /= (old_stdev.mean())
                    old_action = old_action.detach()
                with torch.enable_grad():
                    for i in range(steps):
                        states = states.clone().detach().requires_grad_()
                        # Mean action of the policy at the perturbed states.
                        actions = self.policy_model(states)[0]
                        value = self.sarsa_network(torch.cat((last_states, actions), dim=1)).mean(dim=1)
                        if use_action:
                            action_change = (actions - old_action) / old_stdev
                            # We want to maximize the action change, thus the minus sign.
                            action_change = -(action_change * action_change).mean(dim=1)
                            loss = action_ratio * action_change + (1.0 - action_ratio) * value
                        else:
                            action_change = 0.0
                            loss = value
                        loss.backward()
                        update = states.grad.sign() * step_eps
                        # Clamp to +/- eps.
                        states.data = torch.min(torch.max(states.data - update, clamp_min), clamp_max)
                    self.sarsa_network.zero_grad()
                    self.policy_model.zero_grad()
                return states.detach()
            else:
                return last_states
        elif self.params.ATTACK_METHOD == "none":
            return last_states
        else:
            raise ValueError(f'Unknown attack method {self.params.ATTACK_METHOD}')


    """Run trajectories and return saps and values for each state."""
    def collect_saps(self, num_saps, should_log=True, return_rewards=False,
                     should_tqdm=False, test=False):
        with torch.no_grad():
            # Run trajectories, get values, estimate advantage
            output = self.run_trajectories(num_saps,
                                           return_rewards=return_rewards,
                                           should_tqdm=should_tqdm)

            if not return_rewards:
                avg_ep_length, avg_ep_reward, trajs = output
            else:
                avg_ep_length, avg_ep_reward, trajs, ep_rewards = output

            # No need to compute advantage function for testing.
            if not test:
                # If we are sharing weights between the policy network and
                # the value network, we use the get_value function of the
                # *policy* to estimate the value, instead of using the
                # separate value net.
                if not self.SHARE_WEIGHTS:
                    values = self.val_model(trajs.states).squeeze(-1)
                else:
                    values = self.policy_model.get_value(trajs.states).squeeze(-1)

                # Calculate advantages and returns
                advantages, returns = self.advantage_and_return(trajs.rewards,
                                                values, trajs.not_dones)

                trajs.advantages = advantages
                trajs.returns = returns
                trajs.values = values

                assert shape_equal_cmp(trajs.advantages, 
                                trajs.returns, trajs.values)

            # Logging
            if should_log:
                msg = "Current mean reward: %f | mean episode length: %f"
                print(msg % (avg_ep_reward, avg_ep_length))
                if not test:
                    self.store.log_table_and_tb('optimization', {
                        'mean_reward': avg_ep_reward
                    })

            # Unroll the trajectories (actors, T, ...) -> (actors*T, ...)
            saps = trajs.unroll()

        to_ret = (saps, avg_ep_reward, avg_ep_length)
        if return_rewards:
            to_ret += (ep_rewards,)

        return to_ret


    def sarsa_steps(self, saps):
        # Begin advanced logging code
        assert saps.unrolled
        loss = torch.nn.SmoothL1Loss()
        action_std = torch.exp(self.policy_model.log_stdev).detach().requires_grad_(False)  # Avoid backprop twice.
        # We treat all value epochs as one epoch.
        self.sarsa_eps_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
        self.sarsa_beta_scheduler.set_epoch_length(self.params.VAL_EPOCHS * self.params.NUM_MINIBATCHES)
        # We count from 1.
        self.sarsa_eps_scheduler.step_epoch()
        self.sarsa_beta_scheduler.step_epoch()
        # saps contains state->action->reward and not_done.
        for i in range(self.params.VAL_EPOCHS):
            # Create minibatches with shuffling
            state_indices = np.arange(saps.rewards.nelement())
            np.random.shuffle(state_indices)
            splits = np.array_split(state_indices, self.params.NUM_MINIBATCHES)

            # Minibatch SGD
            for selected in splits:
                def sel(*args):
                    return [v[selected] for v in args]

                self.sarsa_opt.zero_grad()
                sel_states, sel_actions, sel_rewards, sel_not_dones = sel(saps.states, saps.actions, saps.rewards, saps.not_dones)
                
                self.sarsa_eps_scheduler.step_batch()
                self.sarsa_beta_scheduler.step_batch()
                
                inputs = torch.cat((sel_states, sel_actions), dim=1)
                # action_diff = self.sarsa_eps_scheduler.get_eps() * action_std
                # inputs_lb = torch.cat((sel_states, sel_actions - action_diff), dim=1).detach().requires_grad_(False)
                # inputs_ub = torch.cat((sel_states, sel_actions + action_diff), dim=1).detach().requires_grad_(False)
                # bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=None, x_L=inputs_lb, x_U=inputs_ub))
                bounded_inputs = BoundedTensor(inputs, ptb=PerturbationLpNorm(norm=np.inf, eps=self.sarsa_eps_scheduler.get_eps()))

                q = self.relaxed_sarsa_model(bounded_inputs).squeeze(-1)
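                # Pair each Q value with the following minibatch entry to form a
                # bootstrapped (reward + gamma * Q_next) regression target.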
                q_old = q[:-1]
                q_next = q[1:] * self.GAMMA * sel_not_dones[:-1] + sel_rewards[:-1]
                q_next = q_next.detach()
                # q_loss = (q_old - q_next).pow(2).sum(dim=-1).mean()
                q_loss = loss(q_old, q_next)
                # Compute the robustness regularization.
                if self.sarsa_eps_scheduler.get_eps() > 0 and self.params.SARSA_REG > 0:
                    beta = self.sarsa_beta_scheduler.get_eps()
                    ilb, iub = self.relaxed_sarsa_model.compute_bounds(IBP=True, method=None)
                    if beta < 1:
                        clb, cub = self.relaxed_sarsa_model.compute_bounds(IBP=False, method='backward')
                        lb = beta * ilb + (1 - beta) * clb
                        ub = beta * iub + (1 - beta) * cub
                    else:
                        lb = ilb
                        ub = iub
                    # Output dimension is 1. Remove the extra dimension and keep only the batch dimension.
                    lb = lb.squeeze(-1)
                    ub = ub.squeeze(-1)
                    diff = torch.max(ub - q, q - lb)
                    reg_loss = self.params.SARSA_REG * (diff * diff).mean()
                    sarsa_loss = q_loss + reg_loss
                    reg_loss = reg_loss.item()
                else:
                    reg_loss = 0.0
                    sarsa_loss = q_loss
                sarsa_loss.backward()
                self.sarsa_opt.step()
            print(f'q_loss={q_loss.item():.6g}, reg_loss={reg_loss:.6g}, sarsa_loss={sarsa_loss.item():.6g}')

        if self.ANNEAL_LR:
            self.sarsa_scheduler.step()
        # print('value:', self.val_model(saps.states).mean().item())

        return q_loss, q.mean()


    def take_steps(self, saps, logging=True, value_only=False):
        # Begin advanced logging code
        assert saps.unrolled
        should_adv_log = self.advanced_logging and \
                     self.n_steps % self.log_every == 0 and logging

        self.params.SHOULD_LOG_KL = self.advanced_logging and \
                        self.KL_APPROXIMATION_ITERS != -1 and \
                        self.n_steps % self.KL_APPROXIMATION_ITERS == 0
        store_to_pass = self.store if should_adv_log else None
        # End logging code

        if should_adv_log:
            # Collect some extra trajectories for validation of KL and max KL.
            num_saps = saps.advantages.shape[0]
            val_saps = self.collect_saps(num_saps, should_log=False)[0]

            out_train = self.policy_model(saps.states)
            out_val = self.policy_model(val_saps.states)

            old_pds = select_prob_dists(out_train, detach=True)
            val_old_pds = select_prob_dists(out_val, detach=True)

        # Update the value function before the policy step.
        # Pass the logging store into the function if applicable.
        val_loss = ch.tensor(0.0)
        if not self.SHARE_WEIGHTS:
            val_loss = value_step(saps.states, saps.returns, 
                saps.advantages, saps.not_dones, self.val_model,
                self.val_opt, self.params, store_to_pass).mean()

        if self.ANNEAL_LR:
            self.VALUE_SCHEDULER.step()

        if value_only:
            # Run the value iteration only. Return now.
            return val_loss

        if logging:
            self.store.log_table_and_tb('optimization', {
                'final_value_loss': val_loss
            })

        if self.MODE == 'robust_ppo' and logging:
            # Logging Robust PPO KL, entropy, etc.
            store_to_pass = self.store

        # Take optimizer steps
        args = [saps.states, saps.actions, saps.action_log_probs,
                saps.rewards, saps.returns, saps.not_dones, 
                saps.advantages, self.policy_model, self.params, 
                store_to_pass, self.n_steps]

        if self.MODE == 'robust_ppo' and isinstance(self.policy_model, CtsPolicy):
            args += [self.relaxed_policy_model, self.robust_eps_scheduler, self.robust_beta_scheduler]

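        # Anneal the trust-region size: MAX_KL grows by a fixed increment each training step.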
        self.MAX_KL += self.MAX_KL_INCREMENT 
        if should_adv_log:
            # Save old parameter to investigate weight updates.
            old_parameter = copy.deepcopy(self.policy_model.state_dict())

        # Policy optimization step
        surr_loss = self.policy_step(*args).mean()

        # If the anneal_lr option is set, then we decrease the 
        # learning rate at each training step
        if self.ANNEAL_LR:
            self.POLICY_SCHEDULER.step()

        if should_adv_log:
            log_value_losses(self, val_saps, 'heldout')
            log_value_losses(self, saps, 'train')
            old_pds = saps.action_means, saps.action_std
            paper_constraints_logging(self, saps, old_pds,
                            table='paper_constraints_train')
            paper_constraints_logging(self, val_saps, val_old_pds,
                            table='paper_constraints_heldout')
            log_weight_updates(self, old_parameter, self.policy_model.state_dict())

            self.store['paper_constraints_train'].flush_row()
            self.store['paper_constraints_heldout'].flush_row()
            self.store['value_data'].flush_row()
            self.store['weight_updates'].flush_row()
        if self.params.MODE == 'robust_ppo':
            self.store['robust_ppo_data'].flush_row()

        return surr_loss, val_loss


    def train_step(self):
        '''
        Take a training step, by first collecting rollouts, then 
        calculating advantages, then taking a policy gradient step, and 
        finally taking a value function step.

        Inputs: None
        Returns: 
        - The current reward from the policy (per actor)
        '''
        print("-" * 80)
        start_time = time.time()

        num_saps = self.T * self.NUM_ACTORS
        saps, avg_ep_reward, avg_ep_length = self.collect_saps(num_saps)
        surr_loss, val_loss = self.take_steps(saps)
        # Logging code
        print("Surrogate Loss:", surr_loss.item(), 
                        "| Value Loss:", val_loss.item())
        print("Time elapsed (s):", time.time() - start_time)
        if not self.policy_model.discrete:
            mean_std = ch.exp(self.policy_model.log_stdev).mean()
            print("Agent stdevs: %s" % mean_std)
            self.store.log_table_and_tb('optimization', {
                'mean_std': mean_std
            })
        else:
            self.store['optimization'].update_row({
                'mean_std':np.nan
            })

        self.store['optimization'].flush_row()
        sys.stdout.flush()
        sys.stderr.flush()
        # End logging code

        self.n_steps += 1
        return avg_ep_reward

    def sarsa_step(self):
        '''
        Take a training step, by first collecting rollouts, and 
        taking a value function step.

        Inputs: None
        Returns: 
        - The current reward from the policy (per actor)
        '''
        print("-" * 80)
        start_time = time.time()

        num_saps = self.T * self.NUM_ACTORS
        saps, avg_ep_reward, avg_ep_length = self.collect_saps(num_saps, should_log=True, test=True)
        sarsa_loss, q = self.sarsa_steps(saps)
        print("Sarsa Loss:", sarsa_loss.item())
        print("Q:", q.item())
        print("Time elapsed (s):", time.time() - start_time)
        sys.stdout.flush()
        sys.stderr.flush()

        self.n_steps += 1
        return avg_ep_reward

    def run_test(self, max_len=2048, compute_bounds=False, use_full_backward=False):
        print("-" * 80)
        start_time = time.time()
        if compute_bounds:
            self.create_relaxed_model()
        #saps, avg_ep_reward, avg_ep_length = self.collect_saps(num_saps=None, should_log=True, test=True, num_episodes=num_episodes)
        with torch.no_grad():
            output = self.run_test_trajectories(max_len=max_len)
            ep_length, ep_reward, actions, action_means, states = output
            msg = "Episode reward: %f | episode length: %f"
            print(msg % (ep_reward, ep_length))
            if compute_bounds:
                eps = float(self.params.ROBUST_PPO_EPS) if self.params.ATTACK_EPS == "same" else float(self.params.ATTACK_EPS)
                kl_upper_bound = get_state_kl_bound(self.relaxed_policy_model, states, action_means,
                        eps=eps, beta=0.0,
                        stdev=self.policy_model.log_stdev, use_full_backward=use_full_backward).mean()
                kl_upper_bound = kl_upper_bound.item()
            else:
                kl_upper_bound = float("nan")
        return ep_length, ep_reward, actions.cpu().numpy(), action_means.cpu().numpy(), states.cpu().numpy(), kl_upper_bound

    def run_test_trajectories(self, max_len, should_tqdm=False):
        # Arrays to be updated with historic info
        envs = self.envs
        initial_states = self.reset_envs(envs)

        # Holds information (length and true reward) about completed episodes
        completed_episode_info = []

        shape = (1, max_len)
        rewards = ch.zeros(shape)

        actions_shape = shape + (self.NUM_ACTIONS,)
        actions = ch.zeros(actions_shape)
        # Mean of the action distribution. Used to avoid unnecessary recomputation.
        action_means = ch.zeros(actions_shape)

        states_shape = (1, max_len+1) + initial_states.shape[2:]
        states = ch.zeros(states_shape)

        iterator = range(max_len) if not should_tqdm else tqdm.trange(max_len)


        states[:, 0, :] = initial_states
        last_states = states[:, 0, :]
        
        for t in iterator:
            if (t+1) % 100 == 0:
                print('Step {} '.format(t+1))
            # assert shape_equal([self.NUM_ACTORS, self.NUM_FEATURES], last_states)
            # Retrieve probabilities 
            # action_pds: (# actors, # actions), prob dists over actions
            # next_actions: (# actors, 1), indices of actions
            maybe_attacked_last_states = self.apply_attack(last_states)
            action_pds = self.policy_model(maybe_attacked_last_states)
            next_action_means, next_action_stds = action_pds
            # Double check if the attack is within eps range.
            if self.params.ATTACK_METHOD != "none":
                max_eps = (maybe_attacked_last_states - last_states).abs().max()
                attack_eps = float(self.params.ROBUST_PPO_EPS) if self.params.ATTACK_EPS == "same" else float(self.params.ATTACK_EPS)
                if max_eps > attack_eps + 1e-5:
                    raise RuntimeError(f"{max_eps} > {attack_eps}. Attack implementation has a bug; eps is not correctly handled.")
            next_actions = self.policy_model.sample(action_pds)


            # If discrete, next_actions is (# actors, 1);
            # otherwise, if continuous, (# actors, 1, action dim).
            next_actions = next_actions.unsqueeze(1)

            ret = self.multi_actor_step(next_actions, envs)

            # done_info = List of (length, reward) pairs for each completed trajectory
            # (next_rewards, next_states, next_not_dones) act like a multi-actor env.step()
            done_info, next_rewards, next_states, next_not_dones = ret

            # Update histories
            # each shape: (nact, t, ...) -> (nact, t + 1, ...)

            pairs = [
                (rewards, next_rewards),
                (actions, next_actions), # The sampled actions.
                (action_means, next_action_means), # The Gaussian mean of actions.
                (states, next_states),
            ]

            last_states = next_states[:, 0, :]
            for total, v in pairs:
                if total is states:
                    # Next states, stores in the next position.
                    total[:, t+1] = v
                else:
                    # The current action taken, and reward received.
                    total[:, t] = v
            
            # Stop as soon as the first episode completes (a single test episode is run).
            if len(done_info) > 0:
                completed_episode_info.extend(done_info)
                break

        if len(completed_episode_info) > 0:
            ep_length, ep_reward = completed_episode_info[0]
        else:
            ep_length = np.nan
            ep_reward = np.nan

        actions = actions[0][:t+1]
        action_means = action_means[0][:t+1]
        states = states[0][:t+1]

        to_ret = (ep_length, ep_reward, actions, action_means, states)
        
        
        return to_ret

    @staticmethod
    def agent_from_data(store, row, cpu, extra_params=None, override_params=None, excluded_params=None):
        '''
        Initializes an agent from serialized data (via cox)
        Inputs:
        - store, the cox store where everything is logged
        - row, the exact row containing the desired data for this agent
        - cpu, True/False whether to use the CPU (otherwise sends to GPU)
        - extra_params, a dictionary of extra agent parameters. Only used
          when a key does not exist from the loaded cox store.
        - override_params, a dictionary of agent parameters that will override
          current agent parameters.
        - excluded_params, a dictionary of parameters that we do not copy or
          override.
        Outputs:
        - agent, a constructed agent with the desired initialization and
              parameters
        - agent_params, the parameters that the agent was constructed with
        '''

        ckpts = store['final_results']

        get_item = lambda x: list(row[x])[0]

        items = ['val_model', 'policy_model', 'val_opt', 'policy_opt']
        names = {i: get_item(i) for i in items}

        param_keys = list(store['metadata'].df.columns)
        param_values = list(store['metadata'].df.iloc[0,:])

        def process_item(v):
            try:
                return v.item()
            except:
                return v

        param_values = [process_item(v) for v in param_values]
        agent_params = {k:v for k, v in zip(param_keys, param_values)}

        if 'adam_eps' not in agent_params: 
            agent_params['adam_eps'] = 1e-5
        if 'cpu' not in agent_params:
            agent_params['cpu'] = cpu

        # Update extra params if they do not exist in current parameters.
        if extra_params is not None:
            for k in extra_params.keys():
                if k not in agent_params and k not in excluded_params:
                    print(f'adding key {k}={extra_params[k]}')
                    agent_params[k] = extra_params[k]
        if override_params is not None:
            for k in override_params.keys():
                if k not in excluded_params and override_params[k] is not None and override_params[k] != agent_params[k]:
                    print(f'overwriting key {k}: old={agent_params[k]}, new={override_params[k]}')
                    agent_params[k] = override_params[k]

        agent = Trainer.agent_from_params(agent_params)

        def load_state_dict(model, ckpt_name):
            mapper = ch.device('cuda:0') if not cpu else ch.device('cpu')
            state_dict = ckpts.get_state_dict(ckpt_name, map_location=mapper)
            model.load_state_dict(state_dict)

        load_state_dict(agent.policy_model, names['policy_model'])
        load_state_dict(agent.val_model, names['val_model'])
        if agent.ANNEAL_LR:
            agent.POLICY_SCHEDULER.last_epoch = get_item('iteration')
            agent.VALUE_SCHEDULER.last_epoch = get_item('iteration')
        load_state_dict(agent.POLICY_ADAM, names['policy_opt'])
        load_state_dict(agent.val_opt, names['val_opt'])
        agent.envs = ckpts.get_pickle(get_item('envs'))

        return agent, agent_params
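
    # Hypothetical usage sketch (the log directory, experiment id, and row
    # selection below are assumptions, not values taken from this repository):
    #   store = cox.store.Store('/path/to/logdir', 'exp_id')
    #   row = store['final_results'].df.iloc[-1:]
    #   agent, agent_params = Trainer.agent_from_data(
    #       store, row, cpu=True, extra_params={}, override_params={}, excluded_params={})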

    @staticmethod
    def agent_from_params(params, store=None):
        '''
        Construct a trainer object given a dictionary of hyperparameters.
        Trainer is in charge of sampling trajectories, updating policy network,
        updating value network, and logging.
        Inputs:
        - params, dictionary of required hyperparameters
        - store, a cox.Store object if logging is enabled
        Outputs:
        - A Trainer object for training a PPO/TRPO agent
        '''
        agent_policy = policy_net_with_name(params['policy_net_type'])
        agent_value = value_net_with_name(params['value_net_type'])

        advanced_logging = params['advanced_logging'] and store is not None
        log_every = params['log_every'] if store is not None else 0

        if params['cpu']:
            torch.set_num_threads(1)
        p = Trainer(agent_policy, agent_value, params, store, log_every=log_every,
                    advanced_logging=advanced_logging)

        return p
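
# --------------------------------------------------------------------------
# Illustrative sketch (not part of the Trainer class above): the "critic"
# attack in Trainer.apply_attack is projected sign-gradient descent on the
# value network inside an l_inf ball around the observed states. A minimal,
# self-contained version follows; `value_net` is an assumed stand-in for any
# nn.Module mapping states to scalar values.
# --------------------------------------------------------------------------
import torch


def critic_pgd_attack(value_net, states, eps, steps):
    """Perturb `states` inside an l_inf ball of radius `eps` so that the
    predicted value is minimized: `steps` sign-gradient steps, random start."""
    step_eps = eps / steps
    clamp_min, clamp_max = states - eps, states + eps
    # Random start inside the first step region.
    x = states + torch.empty_like(states).uniform_(-step_eps, step_eps)
    # enable_grad() mirrors the original code, which runs under torch.no_grad().
    with torch.enable_grad():
        for _ in range(steps):
            x = x.clone().detach().requires_grad_()
            value = value_net(x).sum()
            value.backward()
            # Step against the value gradient, then project back onto the eps ball.
            x = torch.min(torch.max(x - step_eps * x.grad.sign(), clamp_min), clamp_max)
    # Gradients also accumulate on value_net's parameters; clear them afterwards.
    value_net.zero_grad()
    return x.detach()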
Exemple #15
0
class RobustDeterministicActorCriticNet(nn.Module, BaseNet):
    def __init__(self,
                 state_dim,
                 action_dim,
                 actor_network,
                 critic_network,
                 mini_batch_size,
                 actor_opt_fn,
                 critic_opt_fn,
                 robust_params=None):
        super(RobustDeterministicActorCriticNet, self).__init__()

        if robust_params is None:
            robust_params = {}
        self.use_loss_fusion = robust_params.get('use_loss_fusion', False) # Use loss fusion to reduce complexity for convex relaxation. Default is False.
        self.use_full_backward = robust_params.get('use_full_backward', False)
        if self.use_loss_fusion:
            # Use auto_LiRPA to compute the L2 norm directly.
            self.fc_action = model_mlp_any_with_loss(state_dim, actor_network, action_dim)
            modules = self.fc_action._modules
            # Auto LiRPA wrapper
            self.fc_action = BoundedModule(
                    self.fc_action, (torch.empty(size=(1, state_dim)), torch.empty(size=(1, action_dim))), device=Config.DEVICE)
            # self.fc_action._modules = modules
            for n in self.fc_action.nodes:
                # Find the tanh neuron in computational graph
                if isinstance(n, BoundTanh):
                    self.fc_action_after_tanh = n
                    self.fc_action_pre_tanh = n.inputs[0]
                    break
        else:
            # Fully connected layer with [state_dim, 400, 300, action_dim] neurons and ReLU activation function
            self.fc_action = model_mlp_any(state_dim, actor_network, action_dim)
            # auto_lirpa wrapper
            self.fc_action = BoundedModule(
                    self.fc_action, (torch.empty(size=(1, state_dim)), ), device=Config.DEVICE)

        # Fully connected layer with [state_dim + action_dim, 400, 300, 1]
        self.fc_critic = model_mlp_any(state_dim + action_dim, critic_network, 1)
        # auto_lirpa wrapper
        self.fc_critic = BoundedModule(
                self.fc_critic, (torch.empty(size=(1, state_dim + action_dim)), ), device=Config.DEVICE)

        self.actor_params = self.fc_action.parameters()
        self.critic_params = self.fc_critic.parameters()

        self.actor_opt = actor_opt_fn(self.actor_params)
        self.critic_opt = critic_opt_fn(self.critic_params)
        self.to(Config.DEVICE)
        # Create identity specification matrices
        self.actor_identity = torch.eye(action_dim).repeat(mini_batch_size,1,1).to(Config.DEVICE)
        self.critic_identity = torch.eye(1).repeat(mini_batch_size,1,1).to(Config.DEVICE)
        self.action_dim = action_dim
        self.state_dim = state_dim

    def forward(self, obs):
        phi = self.feature(obs)
        action = self.actor(phi)
        return action

    def feature(self, obs):
        # Not used; originally this was a feature extraction network.
        return tensor(obs)

    def actor(self, phi):
        if self.use_loss_fusion:
            self.fc_action(phi, torch.zeros(size=phi.size()[:1] + (self.action_dim,), device=Config.DEVICE))
            return self.fc_action_after_tanh.forward_value
        else:
            return torch.tanh(self.fc_action(phi, method_opt="forward"))

    # Obtain element-wise lower and upper bounds for actor network through convex relaxations.
    def actor_bound(self, phi_lb, phi_ub, beta=1.0, eps=None, norm=np.inf, upper=True, lower=True, phi = None, center = None):
        if self.use_loss_fusion: # Use loss fusion (not typically enabled)
            assert center is not None
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
            x = BoundedTensor(phi, ptb)
            val = self.fc_action(x, center.detach())
            ilb, iub = self.fc_action.compute_bounds(IBP=True, method=None)
            if beta > 1e-10:
                clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward", bound_lower=False, bound_upper=True)
                ub = cub * beta + iub * (1.0 - beta)
                return ub
            else:
                return iub
        else:
            assert center is None
            # Invoke auto_LiRPA for convex relaxation.
            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=phi_lb, x_U=phi_ub)
            x = BoundedTensor(phi, ptb)
            if self.use_full_backward:
                clb, cub = self.fc_action.compute_bounds(x=(x,), IBP=False, method="backward")
                return cub, clb
            else:
                ilb, iub = self.fc_action.compute_bounds(x=(x,), IBP=True, method=None)
                if beta > 1e-10:
                    clb, cub = self.fc_action.compute_bounds(IBP=False, method="backward")
                    ub = cub * beta + iub * (1.0 - beta)
                    lb = clb * beta + ilb * (1.0 - beta)
                    return ub, lb
                else:
                    return iub, ilb


    def critic(self, phi, a):
        return self.fc_critic(torch.cat([phi, a], dim=1), method_opt="forward")

    # Obtain element-wise lower and upper bounds for critic network through convex relaxations.
    def critic_bound(self, phi_lb, phi_ub, a_lb, a_ub, beta=1.0, eps=None, phi=None, action=None, norm=np.inf, upper=True, lower=True):
        x_L = torch.cat([phi_lb, a_lb], dim=1)
        x_U = torch.cat([phi_ub, a_ub], dim=1)
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=x_L, x_U=x_U)
        x = BoundedTensor(torch.cat([phi, action], dim=1), ptb)
        ilb, iub = self.fc_critic.compute_bounds(x=(x,), IBP=True, method=None)
        if beta > 1e-10:
            clb, cub = self.fc_critic.compute_bounds(IBP=False, method="backward")
            ub = cub * beta + iub * (1.0 - beta)
            lb = clb * beta + ilb * (1.0 - beta)
            return ub, lb
        else:
            return iub, ilb
        
    def load_state_dict(self, state_dict, strict=True):
        action_dict = OrderedDict()
        critic_dict = OrderedDict()
        for k in state_dict.keys():
            if 'action' in k:
                pos = k.find('.') + 1
                action_dict[k[pos:]] = state_dict[k]
            if 'critic' in k:
                pos = k.find('.') + 1
                critic_dict[k[pos:]] = state_dict[k]
        # Load the actor and critic networks separately; this is required for auto_LiRPA.
        self.fc_action.load_state_dict(action_dict)
        self.fc_critic.load_state_dict(critic_dict)

    def state_dict(self):
        # Save the actor and critic state dicts separately; this is required for auto_LiRPA.
        action_state_dict = self.fc_action.state_dict()
        critic_state_dict = self.fc_critic.state_dict()
        network_state_dict = OrderedDict()
        for k,v in action_state_dict.items():
            network_state_dict["fc_action."+k] = v
        for k,v in critic_state_dict.items():
            network_state_dict["fc_critic."+k] = v
        return network_state_dict
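
# --------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): actor_bound / critic_bound
# and the SARSA robustness regularizer both interpolate cheap IBP bounds with
# tighter backward (CROWN) bounds via a mixing weight beta. The sketch below
# follows the sarsa_steps convention (beta weights the IBP bounds); the tiny
# MLP and input sizes are assumptions for demonstration only.
# --------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm


def mixed_bounds(bounded_model, x, eps, beta):
    """Return a beta-weighted combination of IBP and backward (CROWN) bounds."""
    ptb = PerturbationLpNorm(norm=np.inf, eps=eps)
    bx = BoundedTensor(x, ptb)
    ilb, iub = bounded_model.compute_bounds(x=(bx,), IBP=True, method=None)
    if beta < 1:
        # The backward pass reuses the input registered by the previous call.
        clb, cub = bounded_model.compute_bounds(IBP=False, method='backward')
        return beta * ilb + (1 - beta) * clb, beta * iub + (1 - beta) * cub
    return ilb, iub


net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
bounded_net = BoundedModule(net, torch.empty(1, 8))
lb, ub = mixed_bounds(bounded_net, torch.randn(4, 8), eps=0.1, beta=0.5)
# A robustness regularizer like the one in sarsa_steps would then penalize
# max(ub - q, q - lb) squared, averaged over the batch.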