def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        obs_batch, action_mask = self._unpack_observation(obs_batch)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model, th.from_numpy(obs_batch),
                [th.from_numpy(np.array(s)) for s in state_batches])
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = actions.numpy()
            hiddens = [s.numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
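The pattern above (mask unavailable actions with -inf, then mix greedy and random picks with an epsilon coin flip) is easy to check in isolation. A minimal sketch, with shapes and the epsilon value assumed for illustration:

import torch as th
from torch.distributions import Categorical

q_values = th.randn(4, 3, 5)                  # [batch, agents, actions]
avail = (th.rand(4, 3, 5) > 0.3).float()      # 1.0 = action available
avail[avail.sum(-1) == 0] = 1.0               # ensure every agent has an option

masked_q = q_values.clone()
masked_q[avail == 0.0] = -float("inf")        # unavailable actions are never greedy picks
pick_random = (th.rand(4, 3) < 0.1).long()    # epsilon = 0.1 (assumed)
random_actions = Categorical(avail).sample().long()
actions = pick_random * random_actions + (1 - pick_random) * masked_q.max(dim=2)[1]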
Example #2
def ring(size, sigma=1, tolerance=0, batch=None, dtype=None, device=None):
    '''Gaussian noise with `abs(L2_norm / sqrt(size) - sigma) <= tolerance`.

    Mean of squared L2 norm of Gaussian random vector = trace(covariance).

    We generate a standard independent Gaussian vector x with the given `size`,
    then rescale it so that `L2_norm / sqrt(size)` lies in
    [`sigma` - `tolerance`, `sigma` + `tolerance`], i.e. the norm equals
    `n = sigma * sqrt(size)` up to `tolerance * sqrt(size)`.

    Args:
        size: The size of the generated noise.
        sigma: The scalar input standard deviation.
        tolerance: The tolerance term (as described above).
        batch: The number of noises to generate.
        dtype: The data type.
        device: The device on which to allocate the noise.

    Returns:
        The generated noise.
    '''
    if batch is not None:
        size = [batch] + list(size)
    noise = torch.randn(size, dtype=dtype, device=device)
    flat = noise.view(1 if batch is None else batch, -1)
    norms = flat.norm(dim=1, keepdim=True)
    if tolerance != 0:
        sigma = torch.rand_like(norms) * (2 * tolerance) + (sigma - tolerance)
    flat.mul_(sigma * flat.size(1) ** 0.5 / norms)
    return noise
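A quick usage sketch for ring (shape and parameter values assumed for illustration):

noise = ring((3, 4), sigma=2.0, tolerance=0.1, batch=5)
print(noise.shape)                                   # torch.Size([5, 3, 4])
per_dim = noise.view(5, -1).norm(dim=1) / 12 ** 0.5  # 12 entries per sample
print(per_dim)                                       # every value lies in [1.9, 2.1]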
Example #3
def sample_regions(lb: Tensor, ub: Tensor, K: int, depth: int) -> Tuple[Tensor, Tensor]:
    """ Uniformly sample K sub-regions with fixed width boundaries for each sub-region.
    :param lb: Lower bounds, batched
    :param ub: Upper bounds, batched
    :param K: how many pieces to sample
    :param depth: bisecting original region width @depth times for sampling
    """
    assert valid_lb_ub(lb, ub)
    assert K >= 1 and depth >= 1

    repeat_dims = [1] * (len(lb.size()) - 1)
    base = lb.repeat(K, *repeat_dims)  # repeat K times in the batch, preserving the rest dimensions
    orig_width = ub - lb

    try:
        piece_width = orig_width / (2 ** depth)
        # print('Piece width:', piece_width)
        avail_width = orig_width - piece_width
    except RuntimeError as e:
        print('Numerical error at depth', depth)
        raise e

    piece_width = piece_width.repeat(K, *repeat_dims)
    avail_width = avail_width.repeat(K, *repeat_dims)

    coefs = torch.rand_like(base)
    lefts = base + coefs * avail_width
    rights = lefts + piece_width
    return lefts, rights
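Usage sketch for sample_regions (assumes valid_lb_ub from the same module is in scope; shapes are illustrative):

lb, ub = torch.zeros(2, 3), torch.ones(2, 3)
lefts, rights = sample_regions(lb, ub, K=4, depth=2)
print(lefts.shape)          # torch.Size([8, 3]): K copies of the batch
print((rights - lefts)[0])  # each sub-region has width (ub - lb) / 2**depth = 0.25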
Example #4
 def _get_model(self, batch_shape, num_outputs, **tkwargs):
     train_x, train_y = _get_random_data(
         batch_shape=batch_shape, num_outputs=num_outputs, **tkwargs
     )
     train_yvar = (0.1 + 0.1 * torch.rand_like(train_y)) ** 2
     model = HeteroskedasticSingleTaskGP(
         train_X=train_x, train_Y=train_y, train_Yvar=train_yvar
     )
     mll = ExactMarginalLogLikelihood(model.likelihood, model).to(**tkwargs)
     fit_gpytorch_model(mll, options={"maxiter": 1})
     return model
Example #5
def sample_points(lb: Tensor, ub: Tensor, K: int) -> Tensor:
    """ Uniformly sample K points for each region.
    :param lb: Lower bounds, batched
    :param ub: Upper bounds, batched
    :param K: how many points to sample per region
    """
    assert valid_lb_ub(lb, ub)
    assert K >= 1

    repeat_dims = [1] * (len(lb.size()) - 1)
    base = lb.repeat(K, *repeat_dims)  # repeat K times in the batch, preserving the rest dimensions
    width = (ub - lb).repeat(K, *repeat_dims)

    coefs = torch.rand_like(base)
    pts = base + coefs * width
    return pts
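sample_points is the width-zero counterpart of sample_regions above; usage sketch (same assumptions):

pts = sample_points(torch.zeros(2, 3), torch.ones(2, 3), K=4)
print(pts.shape)  # torch.Size([8, 3]); every row lies inside [lb, ub]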
Example #6
def single_test():
    # random input feature map size
    N = np.random.randint(low=1, high=128)
    C = np.random.randint(low=1, high=128)
    H = np.random.randint(low=7, high=128)
    W = np.random.randint(low=7, high=128)
    sigma = np.random.uniform(low=0.2, high=5.0, size=(N, C, H, W))

    norm = L2Norm2d(num_features=C).to(device)
    fnorm = fusedL2Norm2d(num_features=C).to(device)
    fnorm.weight.values = norm.weight.values
    fnorm.bias.values = norm.bias.values

    feature = torch.randn(size=(N, C, H, W), requires_grad=False) * sigma
    feature = feature.to(torch.float32)

    feat_in = feature.clone().to(device).requires_grad_(True)
    feat_in_f = feature.clone().to(device).requires_grad_(True)

    feat_out = norm(feat_in)
    feat_out_f = fnorm(feat_in_f)

    grad = torch.rand_like(feat_out)

    feat_out.backward(grad)
    feat_out_f.backward(grad)

    grad_in = feat_in.grad
    grad_in_f = feat_in_f.grad

    grad_w_in = norm.weight.grad
    grad_b_in = norm.bias.grad

    grad_w_in_f = fnorm.weight.grad
    grad_b_in_f = fnorm.bias.grad

    error_out = torch.abs(feat_out_f - feat_out)
    g_error_in = torch.abs(grad_in - grad_in_f)
    g_error_w = torch.abs(grad_w_in - grad_w_in_f)
    g_error_b = torch.abs(grad_b_in - grad_b_in_f)

    fnorm.eval()
    norm.eval()

    inf_feat_out = norm(feat_in)
    inf_feat_out_f = fnorm(feat_in_f)

    inf_error_out = torch.abs(inf_feat_out_f - inf_feat_out)

    passed = True
    max_error_out = torch.max(error_out).item()
    if max_error_out > 1e-5 or np.isnan(max_error_out):
        print(
            "[Forward] %d of %d entries differ; the maximum difference is %f"
            % (torch.nonzero(error_out).size(0), error_out.numel(),
               max_error_out))
        passed = False

    max_g_error_in = torch.max(g_error_in).item()
    if max_g_error_in > 1e-5 or np.isnan(max_g_error_in):
        print(
            "[Grad in] %d of %d entries differ; the maximum difference is %f"
            % (torch.nonzero(g_error_in).size(0), g_error_in.numel(),
               max_g_error_in))
        passed = False

    max_g_error_w = torch.max(g_error_w).item()
    if max_g_error_w > 1e-5 or np.isnan(max_g_error_w):
        print(
            "[Grad weight] %d of %d entries differ; the maximum difference is %f"
            % (torch.nonzero(g_error_w).size(0), g_error_w.numel(),
               max_g_error_w))
        passed = False

    max_g_error_b = torch.max(g_error_b).item()
    if max_g_error_b > 1e-5 or np.isnan(max_g_error_b):
        print(
            "[Grad bias] %d of %d entries differ; the maximum difference is %f"
            % (torch.nonzero(g_error_b).size(0), g_error_b.numel(),
               max_g_error_b))
        passed = False

    max_inf_error_out = torch.max(inf_error_out).item()
    if max_inf_error_out > 1e-5 or np.isnan(max_inf_error_out):
        print(
            "[Inference] %d of %d entries differ; the maximum difference is %f"
            % (torch.nonzero(inf_error_out).size(0), inf_error_out.numel(),
               max_inf_error_out))
        passed = False

    return passed
Example #7
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Part 1 intro https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#sphx-glr-beginner-blitz-tensor-tutorial-py
torch.empty(5, 3)
torch.empty((5, 3))
torch.rand(5, 3)
torch.zeros(5, 3, dtype=torch.long)
x = torch.tensor([5, 5, 3])
x

x = x.new_ones(5, 3, dtype=torch.double)
x = torch.rand_like(x, dtype=torch.float)
x

x.size()

y = torch.rand(5, 3)
x + y
torch.add(x, y)

result = torch.empty(5, 3)
torch.add(x, y, out=result)
result

y.add_(x)

x[:, 1]
Example #8
    def train_rand_gen(self, loader, lr, wd, max_epochs, c, ld_step, ld_noise,
                       ld_step_size, clamp, alpha, save_interval, save_dir):
        r"""
            Running training for random generation task.

            Args:
                loader: The data loader for loading training samples. It is supposed to use dig.ggraph.dataset.QM9/ZINC250k
                    as the dataset class, and apply torch_geometric.data.DenseDataLoader to it to form the data loader.
                lr (float): The learning rate for training.
                wd (float): The weight decay factor for training.
                max_epochs (int): The maximum number of training epochs.
                c (float): The scaling hyperparameter for dequantization.
                ld_step (int): The number of iteration steps of Langevin dynamics.
                ld_noise (float): The standard deviation of the added noise in Langevin dynamics.
                ld_step_size (int): The step size of Langevin dynamics.
                clamp (bool): Whether to use gradient clamp in Langevin dynamics.
                alpha (float): The weight coefficient for loss function.
                save_interval (int): The frequency to save the model parameters to .pt files,
                    *e.g.*, if save_interval=2, the model parameters will be saved for every 2 training epochs.
                save_dir (str): The directory to save the model parameters.
        """
        parameters = self.energy_function.parameters()
        optimizer = Adam(parameters,
                         lr=lr,
                         betas=(0.0, 0.999),
                         weight_decay=wd)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for epoch in range(max_epochs):
            t_start = time.time()
            losses_reg = []
            losses_en = []
            losses = []
            for i, batch in enumerate(tqdm(loader)):
                ### Dequantization
                pos_x = batch.x.to(self.device).to(dtype=torch.float32)
                pos_x += c * torch.rand_like(pos_x, device=self.device)
                pos_adj = batch.adj.to(self.device).to(dtype=torch.float32)
                pos_adj += c * torch.rand_like(pos_adj, device=self.device)

                ### Langevin dynamics
                neg_x = torch.rand_like(pos_x, device=self.device) * (1 + c)
                neg_adj = torch.rand_like(pos_adj, device=self.device)

                pos_adj = rescale_adj(pos_adj)
                neg_x.requires_grad = True
                neg_adj.requires_grad = True

                requires_grad(parameters, False)
                self.energy_function.eval()

                noise_x = torch.randn_like(neg_x, device=self.device)
                noise_adj = torch.randn_like(neg_adj, device=self.device)
                for k in range(ld_step):

                    noise_x.normal_(0, ld_noise)
                    noise_adj.normal_(0, ld_noise)
                    neg_x.data.add_(noise_x.data)
                    neg_adj.data.add_(noise_adj.data)

                    neg_out = self.energy_function(neg_adj, neg_x)
                    neg_out.sum().backward()
                    if clamp:
                        neg_x.grad.data.clamp_(-0.01, 0.01)
                        neg_adj.grad.data.clamp_(-0.01, 0.01)

                    neg_x.data.add_(neg_x.grad.data, alpha=ld_step_size)
                    neg_adj.data.add_(neg_adj.grad.data, alpha=ld_step_size)

                    neg_x.grad.detach_()
                    neg_x.grad.zero_()
                    neg_adj.grad.detach_()
                    neg_adj.grad.zero_()

                    neg_x.data.clamp_(0, 1 + c)
                    neg_adj.data.clamp_(0, 1)

                ### Training by backprop
                neg_x = neg_x.detach()
                neg_adj = neg_adj.detach()
                requires_grad(parameters, True)
                self.energy_function.train()

                self.energy_function.zero_grad()

                pos_out = self.energy_function(pos_adj, pos_x)
                neg_out = self.energy_function(neg_adj, neg_x)

                loss_reg = pos_out**2 + neg_out**2  # energy magnitude regularizer
                loss_en = pos_out - neg_out  # loss for shaping energy function
                loss = loss_en + alpha * loss_reg
                loss = loss.mean()
                loss.backward()
                clip_grad(parameters, optimizer)
                optimizer.step()

                losses_reg.append(loss_reg.mean())
                losses_en.append(loss_en.mean())
                losses.append(loss)

            t_end = time.time()

            ### Save checkpoints
            if (epoch + 1) % save_interval == 0:
                torch.save(
                    self.energy_function.state_dict(),
                    os.path.join(save_dir, 'epoch_{}.pt'.format(epoch + 1)))
                print('Saving checkpoint at epoch ', epoch + 1)
                print('==========================================')
            print(
                'Epoch: {:03d}, Loss: {:.6f}, Energy Loss: {:.6f}, Regularizer Loss: {:.6f}, Sec/Epoch: {:.2f}'
                .format(epoch + 1, (sum(losses) / len(losses)).item(),
                        (sum(losses_en) / len(losses_en)).item(),
                        (sum(losses_reg) / len(losses_reg)).item(),
                        t_end - t_start))
            print('==========================================')
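The dequantization step at the top of the training loop is worth isolating: node and adjacency entries in {0, 1} get uniform noise c * U(0, 1) added so the energy model sees continuous inputs, which is also why the Langevin samples are clamped to [0, 1 + c]. A minimal sketch (the value of c is assumed):

x = torch.randint(0, 2, (4,)).float()  # discrete 0/1 input
c = 0.05
x_deq = x + c * torch.rand_like(x)     # continuous, in [0, 1 + c)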
Example #9
def a(seed, k):
    a = torch.empty(100, 10, dtype=k.dtype).normal_(-1, 1)
    a /= a.norm(dim=-1, keepdim=True)
    a *= (torch.rand_like(k) * k) ** 0.5
    return lorentz.math.project(a, k=k)
Example #10
                # add l1/l2 regularization
                if args.l1_alpha > 0:
                    l1_penalty = l1_regularization(generator, "weight")
                    g_loss = g_loss + args.l1_alpha * l1_penalty
                if args.l2_alpha > 0:
                    l2_penalty = l2_regularization(generator, "weight")
                    g_loss = g_loss + args.l2_alpha * l2_penalty
                if args.dist_alpha > 0:
                    label_slices, real_data, fake_data = get_speech_slices(
                        label, results, frame_number)
                    # get discriminative loss for generator
                    fake_logits, fake_feat = discriminator((fake_data, ))
                    real_logits, real_feat = discriminator(
                        (real_data.detach(), ))
                    rand_logits, rand_feat = discriminator(
                        (torch.rand_like(real_data) * 2. - 1., ))
                    dist_loss = (fake_logits - real_logits).pow(2.).mean()
                    g_loss = g_loss + dist_loss * args.dist_alpha

                if args.feat_alpha > 0:
                    label_slices, real_data, fake_data = get_speech_slices(
                        label, results, frame_number)
                    # get discriminative loss for generator
                    fake_logits, fake_feat = discriminator((fake_data, ))
                    real_logits, real_feat = discriminator(
                        (real_data.detach(), ))
                    feat_loss = (fake_feat - real_feat).pow(2.).mean()
                    g_loss = g_loss + feat_loss * args.feat_alpha

                recon_loss = reconstruction_loss(
                    predict, target, batch_mask, backward=False,
                    # ... (remaining arguments truncated in the source excerpt)
Example #11
def gumbel_softmax(logits, tau, eps=1e-8):
    U = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(U + eps) + eps)
    y = logits + gumbel
    y = F.softmax(y / tau, dim=1)
    return y
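This helper and the one in Example #16 implement the same Gumbel-Softmax recipe: draw u ~ U(0, 1), form Gumbel noise g = -log(-log(u)), and take softmax((logits + g) / tau). A quick check (imports and values assumed):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1]])
y = gumbel_softmax(logits, tau=0.5)
print(y.sum(dim=1))  # rows sum to 1; smaller tau pushes y toward one-hot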
Example #12
 def get_new_weights(self, m, rate):
     i = (torch.rand_like(m.weight.data) < rate)
     v = torch.zeros_like(m.weight.data)
     return i, v
Example #13
    def loadAndAdaptModel(self, stateDict, data, updateStateFun):
        '''
            Loads a model and a labelclass map from a given "stateDict".
            First calls the parent implementation to obtain a default
            model, then checks for new label classes and modifies the
            model's classification head accordingly.
            TODO: implement advanced modifiers:
            1. Weighted linear combination of images with new annotations
               present
            2. Weighted linear combination according to similarity of name
               of new classes w.r.t. existing ones (e.g. using Word2Vec)
            
            For now, only the smallest existing class weights are used
            and duplicated.
        '''
        model, stateDict, newClasses, _ = self.initializeModel(stateDict, data)
        
        # modify model weights to accept new label classes
        if len(newClasses):

            # create temporary labelclassMap for new classes
            lcMap_new = dict(zip(newClasses, list(range(len(newClasses)))))

            # create vector of label classes
            classVector = len(stateDict['labelclassMap']) * [None]
            for (key, index) in zip(stateDict['labelclassMap'].keys(), stateDict['labelclassMap'].values()):
                classVector[index] = key

            weights = model.sem_seg_head.predictor.weight
            biases = model.sem_seg_head.predictor.bias
            numClasses_orig = len(biases)

            # create weights and biases for new classes
            if True:        #TODO: add flags in config file about strategy
                weights_copy = weights.clone()
                biases_copy = biases.clone()

                #TODO: we currently have no indexing possibilities to retrieve images with correct labels...
                # correlations = self.calculateClassCorrelations(model, lcMap_new, range(numClasses_orig), newClasses, updateStateFun, 128)    #TODO: num images
                
                # use alternative solution: choose random class
                randomOrder = torch.randperm(numClasses_orig)
                for cl in range(len(newClasses)):
                    newWeight = weights_copy[randomOrder[cl],...]
                    newBias = biases_copy[randomOrder[cl]]

                    # add a bit of noise
                    newWeight += (0.5 - torch.rand_like(newWeight)) * 0.5 * torch.std(weights_copy)
                    newBias += (0.5 - torch.rand_like(newBias)) * 0.5 * torch.std(biases_copy)

                    # prepend
                    weights = torch.cat((newWeight.unsqueeze(0), weights), 0)
                    biases = torch.cat((newBias.unsqueeze(0), biases), 0)
                    classVector.insert(0, newClasses[cl])

            # remove old classes
            # valid = torch.ones(len(biases), dtype=torch.bool)
            classMap_updated = {}
            index_updated = 0
            for idx, clName in enumerate(classVector):
                # if clName not in data['labelClasses']:
                #     valid[idx] = 0
                # else:
                if True:    # we don't remove old classes anymore (TODO: flag in configuration)
                    classMap_updated[clName] = index_updated
                    index_updated += 1

            # weights = weights[valid,...]
            # biases = biases[valid,...]

            # apply updated weights and biases
            model.sem_seg_head.predictor.weight = torch.nn.Parameter(weights)
            model.sem_seg_head.predictor.bias = torch.nn.Parameter(biases)

            stateDict['labelclassMap'] = classMap_updated
                
            print(f'Neurons for {len(newClasses)} new label classes added to DeepLabV3+ model.')

        # finally, update model and config
        stateDict['detectron2cfg'].MODEL.SEM_SEG_HEAD.NUM_CLASSES = len(stateDict['labelclassMap'])
        model.num_classes = len(stateDict['labelclassMap'])
        return model, stateDict
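The head-extension trick above (copy an existing class's weight row, jitter it with scaled uniform noise, prepend it) is not specific to the DeepLab predictor. A minimal sketch on a plain nn.Linear head; all names and sizes here are assumptions:

import torch

head = torch.nn.Linear(16, 4)  # 4 existing classes
src = torch.randperm(4)[0]     # random existing class to duplicate
with torch.no_grad():
    new_w = head.weight[src:src + 1].clone()
    new_b = head.bias[src:src + 1].clone()
    new_w += (0.5 - torch.rand_like(new_w)) * 0.5 * head.weight.std()
    new_b += (0.5 - torch.rand_like(new_b)) * 0.5 * head.bias.std()
    head.weight = torch.nn.Parameter(torch.cat((new_w, head.weight), 0))
    head.bias = torch.nn.Parameter(torch.cat((new_b, head.bias), 0))
# head now scores 5 classes; index 0 is the new, jittered copy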
Example #14
def test_unified_scales_are_identical_in_onnx(tmp_path):
    #pylint:disable=no-member
    nncf_config = get_quantization_config_without_range_init(model_size=1)
    nncf_config["compression"]["quantize_outputs"] = True
    nncf_config["input_info"] = [
        {
            "sample_size": [1, 1, 1, 2],
        },
    ]
    nncf_config["target_device"] = "VPU"

    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(
        SimplerModelForUnifiedScalesTesting(),
        nncf_config)

    with torch.no_grad():
        for quant_info in compression_ctrl.non_weight_quantizers.values():
            quant_info.quantizer_module_ref.scale *= torch.abs(torch.rand_like(quant_info.quantizer_module_ref.scale))

    test_input1 = torch.ones([1, 1, 1, 2])
    compressed_model.forward(test_input1)

    onnx_path = tmp_path / "model.onnx"
    compression_ctrl.export_model(onnx_path)

    onnx_model = onnx.load(onnx_path)

    def get_fq_nodes(onnx_model: onnx.ModelProto) -> List[onnx.NodeProto]:
        retval = []
        for node in onnx_model.graph.node:
            if str(node.op_type) == "FakeQuantize":
                retval.append(node)
        return retval

    def immediately_dominates_add_or_mul(node: onnx.NodeProto, graph: onnx.GraphProto) -> bool:
        if len(node.output) != 1:
            return False
        output_tensor_id = node.output[0]
        matches = [x for x in graph.node if output_tensor_id in x.input]
        for match in matches:
            if match.op_type in ["Add", "Mul"]:
                return True
        return False

    def get_successor(node: onnx.NodeProto, graph: onnx.GraphProto) -> onnx.NodeProto:
        assert len(node.output) == 1  # Only single-output nodes are supported in this func
        for target_node in graph.node:
            if node.output[0] in target_node.input:
                return target_node
        return None

    def group_nodes_by_output_target(nodes: List[onnx.NodeProto], graph: onnx.GraphProto) -> List[List[onnx.NodeProto]]:
        output_nodes = {}  # type: Dict[str, List[onnx.NodeProto]]
        for node in nodes:
            target_node_name = get_successor(node, graph).name
            if target_node_name not in output_nodes:
                output_nodes[target_node_name] = []
            output_nodes[target_node_name].append(node)
        return list(output_nodes.values())

    def resolve_constant_node_inputs_to_values(node: onnx.NodeProto, graph: onnx.GraphProto) -> \
            Dict[str, onnx.AttributeProto]:
        retval = {}
        for input_ in node.input:
            constant_input_nodes = [x for x in graph.node if input_ in x.output and x.op_type == "Constant"]
            for constant_input_node in constant_input_nodes:
                assert len(constant_input_node.attribute) == 1
                val = constant_input_node.attribute[0]
                retval[input_] = numpy_helper.to_array(val.t)
        return retval

    fq_nodes = get_fq_nodes(onnx_model)
    eltwise_predicate = partial(immediately_dominates_add_or_mul, graph=onnx_model.graph)
    eltwise_fq_nodes = list(filter(eltwise_predicate, fq_nodes))
    fq_nodes_grouped_by_output = group_nodes_by_output_target(eltwise_fq_nodes, onnx_model.graph)

    for unified_scale_group in fq_nodes_grouped_by_output:
        inputs = [resolve_constant_node_inputs_to_values(fq_node, onnx_model.graph) for fq_node in unified_scale_group]
        for inputs_dict in inputs[1:]:
            curr_values = list(inputs_dict.values())
            ref_values = list(inputs[0].values())
            assert curr_values == ref_values  # All inputs for unified scale quantizers must be equal
Example #15
import torch

## uninitialized Tensor
x = torch.empty(5, 3)
print(x)

## randomly initialized Tensor
y = torch.rand(5, 3)
print(y)

## dtype = long, initialized to 0
z = torch.zeros(5, 3, dtype=torch.long)
print(z)

k = torch.tensor([5.5, 3])
print(k)

k = k.new_ones(5, 3, dtype=torch.double)
print(k)

k = torch.rand_like(k, dtype=torch.float)
print(k)

print(x.size())

y = torch.rand(5, 3)
print(k + y)

Example #16
def gumbel_softmax(p, eps=1e-20):
    u = torch.rand_like(p)
    return torch.softmax(p - torch.log(-torch.log(u + eps) + eps), dim=-1)
Example #17
# (assumed preamble: this excerpt begins mid-script; the values are illustrative)
import numpy as np
import torch

list_data = [[1, 2], [3, 4]]
tensor_from_list = torch.tensor(list_data)
print(tensor_from_list)

# from numpy
np_data = np.array(list_data)
tensor_from_np = torch.from_numpy(np_data)
print(tensor_from_np)

# specify dtype
tensor_from_list_int32 = torch.tensor(list_data, dtype=torch.int32)
print(tensor_from_list_int32)

# from others (param type must be torch.Tensor)
tensor_ones = torch.ones_like(tensor_from_list)
print(f"Ones like: \n{tensor_ones}")

tensor_rand = torch.rand_like(tensor_from_list, dtype=torch.float)
print(f"Rand like: \n{tensor_rand}")

# clear


# %%
# from shape

shape = (3, 4, )  # the trailing comma is optional: (3, 4,) and (3, 4) are the same tuple
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)
print(f"Random Tensor: \n {rand_tensor}")
print(f"Ones Tensor: \n {ones_tensor}")
print(f"Zeros Tensor: \n {zeros_tensor}")
Example #18
 def fn_test_rand(x, y):
     r = torch.rand_like(y)
     return r * x + x
Example #19
 def create(self, x):
     return x * x + x + torch.rand_like(x)
Example #20
 def exchange(self, m1, m2):
     if type(m1) in self.target_modules:
         p = 0.5
         i = torch.rand_like(m1.weight.data) < p
         m1.weight.data[i] = m2.weight.data[i]
         m2.weight.data[~i] = m1.weight.data[~i]  # ~i: mask complement ("1 - i" fails on bool tensors)
Example #21
def gumbel_like(x):
    return -(-torch.rand_like(x).log()).log().requires_grad_(False)
Example #22
def main():
    # device = input('Enter the device to run on, e.g. "cpu" or "cuda:0"  ')
    # dataset_dir = input('Enter the directory where the CIFAR-10 dataset is stored, e.g. "./"  ')
    # class_num = int(input('Enter class_num, e.g. "10"  '))
    # lr = float(input('Enter the learning rate, e.g. "1e-3"  '))
    # phase = input('Enter the algorithm phase, e.g. "train/BIM"  ')

    device = 'cuda:2'
    dataset_dir = '../../dataset/'
    class_num = 10
    lr = 1e-4
    phase = 'BIM'

    torch.cuda.empty_cache()

    if phase == 'train':
        # model_dir = input('Enter the directory for saving model files, e.g. "./"  ')
        # batch_size = int(input('Enter batch_size, e.g. "64"  '))
        # train_epoch = int(input('Enter the number of training epochs (passes over the training set), e.g. "100"  '))
        # log_dir = input('Enter the directory for saving tensorboard log files, e.g. "./"  ')

        model_dir = './models/'
        batch_size = 128
        train_epoch = 9999999
        log_dir = './logs/'

        writer = SummaryWriter(log_dir)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=True,
                transform=transform_train,
                download=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=True)
            
        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=False,
                transform=transform_test,
                download=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=False)

        net = Net().to(device)
        
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        
        train_times = 0
        best_epoch = 0
        max_correct_sum = 0

        for epoch in range(1, train_epoch + 1):
            net.train()

            for X, y in train_data_loader:
                img, label = X.to(device), y.to(device)
                
                optimizer.zero_grad()

                output = net(img)

                # loss =  F.mse_loss(output, F.one_hot(label, class_num).float())
                loss = F.cross_entropy(output, label)

                loss.backward()
                optimizer.step()

                correct_rate = (output.max(1)[1] == label).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate, train_times)
                # if train_times % 1024 == 0:
                #     print(device, dataset_dir, batch_size, lr, train_epoch, log_dir)
                #     print(sys.argv, 'train_times', train_times, 'train_correct_rate', correct_rate)
                train_times += 1

            net.eval()

            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for X, y in test_data_loader:
                    img, label = X.to(device), y.to(device)
                    
                    output = net(img)

                    correct_sum += (output.max(1)[1] == label).float().sum().item()
                    test_sum += label.numel()

                writer.add_scalar('test_correct_rate', correct_sum / test_sum, train_times)

                print('epoch', epoch, 'test_correct_rate', correct_sum / test_sum)

                if correct_sum > max_correct_sum:
                    max_correct_sum = correct_sum
                    torch.save(net.state_dict(), model_dir + 'ann_best_%d.pth' % (epoch))
                    if best_epoch > 0:
                        os.system('rm %sann_best_%d.pth' % (model_dir, best_epoch))
                    best_epoch = epoch

    elif phase == 'BIM':
        # model_path = input('Enter the model file path, e.g. "./model.pth"  ')
        # iter_num = int(input('Enter the number of adversarial-attack iterations, e.g. "25"  '))
        # eta = float(input('Enter the adversarial-attack learning rate, e.g. "0.05"  '))
        # attack_type = input('Enter the attack type, e.g. "UT/T"  ')
        # clip_eps = float(input('Enter the clipping eps, e.g. "0.01"  '))

        model_path = './models/cifar10_ann_v2.pth'
        iter_num = 25
        eta = 0.003
        attack_type = 'UT'

        clip_eps = 0.06

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=False,
                transform=transform_test,
                download=True),
            batch_size=1,
            shuffle=False,
            drop_last=False)

        p_max = transform_test(np.ones((32, 32, 3))).to(device)
        p_min = transform_test(np.zeros((32, 32, 3))).to(device)

        net = Net().to(device)
        net.load_state_dict(torch.load(model_path))

        mean_p = 0.0
        test_sum = 0
        success_sum = 0

        if attack_type == 'UT':
            for X, y in test_data_loader:
                img, label = X.to(device), y.to(device)
                img_ori = torch.rand_like(img).copy_(img)
                img.requires_grad = True

                test_sum += 1

                print('Img %d' % test_sum)

                net.train()

                for it in range(iter_num):
                    output = net(img).unsqueeze(0)

                    # loss = F.mse_loss(output, F.one_hot(label, class_num).float())
                    loss = F.cross_entropy(output, label)

                    loss.backward()

                    img_grad = torch.sign(img.grad.data)

                    img_adv = clip_by_tensor(img + eta * img_grad, img_ori - clip_eps, img_ori + clip_eps, p_min, p_max)

                    img = Variable(img_adv, requires_grad=True)
                
                net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l_norm = torch.max(torch.abs(img_diff)).item()
                    print('Perturbation: %f' % l_norm)

                    mean_p += l_norm

                    output = net(img).unsqueeze(0)

                    attack_flag = (output.max(1)[1] != label).float().sum().item()
                    success_sum += attack_flag

                    if attack_flag > 0.5:
                        print('Attack Success')
                    else:
                        print('Attack Failure')

                if test_sum >= 250:
                    mean_p /= 250
                    break
        else:
            for X, y in test_data_loader:
                for i in range(1, class_num):
                    img, label = X.to(device), y.to(device)
                    img_ori = torch.rand_like(img).copy_(img)
                    img.requires_grad = True
                    
                    target_label = (label + i) % class_num

                    test_sum += 1

                    net.train()

                    for it in range(iter_num):
                        output = net(img).unsqueeze(0)

                        # loss = F.mse_loss(output, F.one_hot(target_label, class_num).float())
                        loss = F.cross_entropy(output, target_label)

                        loss.backward()

                        img_grad = torch.sign(img.grad.data)

                        img_adv = clip_by_tensor(img - eta * img_grad, img_ori - clip_eps, img_ori + clip_eps, p_min, p_max)

                        img = Variable(img_adv, requires_grad=True)
                    
                    net.eval()

                    with torch.no_grad():
                        img_diff = img - img_ori

                        l_norm = torch.max(torch.abs(img_diff)).item()
                        print('Perturbation: %f' % l_norm)

                        mean_p += l_norm

                        output = net(img).unsqueeze(0)

                        attack_flag = (output.max(1)[1] == target_label).float().sum().item()
                        success_sum += attack_flag

                        if attack_flag > 0.5:
                            print('Attack Success')
                        else:
                            print('Attack Failure')

                        '''
                        samples = img.permute(0, 2, 3, 1).data.cpu().numpy()

                        im = np.repeat(samples[0], 3, axis=2)
                        im_path = 'demo/%d_to_%d.png' % (label.item(), target_label.item())
                        print(im_path)
                        print(output)
                        plt.imsave(im_path, im)
                        '''

                if test_sum >= 270:
                    mean_p /= 270
                    break
        
        print('Mean Perturbation: %.3f' % mean_p)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / test_sum))

    elif phase == 'PGD':
        # model_path = input('Enter the model file path, e.g. "./model.pth"  ')
        # iter_num = int(input('Enter the number of adversarial-attack iterations, e.g. "25"  '))
        # eta = float(input('Enter the adversarial-attack learning rate, e.g. "0.05"  '))
        # max_eps = float(input('Enter the maximum perturbation magnitude, e.g. "4.0"  '))

        model_path = './models/cifar10_ann_v1.pth'
        iter_num = 100
        max_eps = 0.3
        eta = 0.01
        attack_type = 'T'

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=False,
                transform=transform_test,
                download=True),
            batch_size=1,
            shuffle=True,
            drop_last=False)

        net = Net().to(device)
        net.load_state_dict(torch.load(model_path))

        mean_p2 = 0.0
        mean_pmax = 0.0
        test_sum = 0
        success_sum = 0

        for X, y in test_data_loader:
            for i in range(1, class_num):
                img, label = X.to(device), y.to(device)
                img_ori = torch.rand_like(img).copy_(img)
                img.requires_grad = True
                    
                target_label = (label + i) % class_num

                test_sum += 1

                net.train()

                for it in range(iter_num):
                    output = net(img).unsqueeze(0)

                    # loss = F.mse_loss(output, F.one_hot(target_label, class_num).float())
                    loss = F.cross_entropy(output, target_label)

                    loss.backward()

                    img_grad = torch.sign(img.grad.data)

                    perturbation = img - eta * img_grad - img_ori

                    # l2_norm = torch.norm(perturbation.view(perturbation.size()[0], -1), dim=1).item()
                    lmax_norm = torch.max(torch.abs(perturbation)).item()
                    
                    # if l2_norm > max_eps:
                        # perturbation = perturbation * max_eps / l2_norm

                    if lmax_norm > max_eps:
                        perturbation = perturbation * max_eps / lmax_norm

                    img_adv = torch.clamp(img_ori + perturbation, 0.0, 1.0)

                    img = Variable(img_adv, requires_grad=True)
                    
                net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l2_norm = torch.norm(img_diff.view(img_diff.size()[0], -1), dim=1).item()
                    lmax_norm = torch.max(torch.abs(img_diff)).item()
                    
                    print('L2 Perturbation: %f' % l2_norm)
                    print('Lmax Perturbation: %f' % lmax_norm)

                    mean_p2 += l2_norm
                    mean_pmax += lmax_norm

                    output = net(img).unsqueeze(0)

                    attack_flag = (output.max(1)[1] == target_label).float().sum().item()
                    success_sum += attack_flag

                    if attack_flag > 0.5:
                        print('Attack Success')
                    else:
                        print('Attack Failure')

                    '''
                    samples = img.permute(0, 2, 3, 1).data.cpu().numpy()

                    im = np.repeat(samples[0], 3, axis=2)
                    im_path = 'demo/%d_to_%d.png' % (label.item(), target_label.item())
                    print(im_path)
                    print(output)
                    plt.imsave(im_path, im)
                    '''

                if test_sum >= 270:
                    mean_p2 /= 270
                    mean_pmax /= 270
                    break
        
        print('Mean L2 Perturbation: %.2f' % mean_p2)
        print('Mean Lmax Perturbation: %.2f' % mean_pmax)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / test_sum))

    elif phase == 'OPA':
        # model_path = input('Enter the model file path, e.g. "./model.pth"  ')

        model_path = './models/cifar10_ann_v1.pth'
        attack_type = 'T'
        # assumed: reuse the PGD-phase hyperparameters (this branch sets none of its own)
        iter_num = 100
        max_eps = 0.3
        eta = 0.01

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=False,
                transform=transform_test,
                download=True),
            batch_size=1,
            shuffle=True,
            drop_last=False)

        net = Net().to(device)
        net.load_state_dict(torch.load(model_path))

        mean_p2 = 0.0
        mean_pmax = 0.0
        test_sum = 0
        success_sum = 0

        for X, y in test_data_loader:
            for i in range(1, class_num):
                img, label = X.to(device), y.to(device)
                img_ori = torch.rand_like(img).copy_(img)
                img.requires_grad = True  # needed for loss.backward() in the attack loop
                    
                target_label = (label + i) % class_num

                test_sum += 1

                net.eval()

                for it in range(iter_num):
                    output = net(img).unsqueeze(0)

                    # loss = F.mse_loss(output, F.one_hot(target_label, class_num).float())
                    loss = F.cross_entropy(output, target_label)

                    loss.backward()

                    img_grad = torch.sign(img.grad.data)

                    perturbation = img - eta * img_grad - img_ori

                    # l2_norm = torch.norm(perturbation.view(perturbation.size()[0], -1), dim=1).item()
                    lmax_norm = torch.max(torch.abs(perturbation)).item()
                    
                    # if l2_norm > max_eps:
                        # perturbation = perturbation * max_eps / l2_norm

                    if lmax_norm > max_eps:
                        perturbation = perturbation * max_eps / lmax_norm

                    img_adv = torch.clamp(img_ori + perturbation, 0.0, 1.0)

                    img = Variable(img_adv, requires_grad=True)
                    
                net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l2_norm = torch.norm(img_diff.view(img_diff.size()[0], -1), dim=1).item()
                    lmax_norm = torch.max(torch.abs(img_diff)).item()
                    
                    print('L2 Perturbation: %f' % l2_norm)
                    print('Lmax Perturbation: %f' % lmax_norm)

                    mean_p2 += l2_norm
                    mean_pmax += lmax_norm

                    output = net(img).unsqueeze(0)

                    attack_flag = (output.max(1)[1] == target_label).float().sum().item()
                    success_sum += attack_flag

                    if attack_flag > 0.5:
                        print('Attack Success')
                    else:
                        print('Attack Failure')

                    '''
                    samples = img.permute(0, 2, 3, 1).data.cpu().numpy()

                    im = np.repeat(samples[0], 3, axis=2)
                    im_path = 'demo/%d_to_%d.png' % (label.item(), target_label.item())
                    print(im_path)
                    print(output)
                    plt.imsave(im_path, im)
                    '''

                if test_sum >= 270:
                    mean_p2 /= 270
                    mean_pmax /= 270
                    break
        
        print('Mean L2 Perturbation: %.2f' % mean_p2)
        print('Mean Lmax Perturbation: %.2f' % mean_pmax)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / test_sum))
Example #23
 def forward(ctx, scores):
     output = (torch.rand_like(scores) < scores).float()
     return output
Example #24
    def select_action(self,
                      agents_inputs_alone,
                      agents_inputs,
                      avail_actions,
                      t_env,
                      test_mode=False):

        # Assuming agent_inputs is a batch of Q-values for each agent
        self.epsilon = self.schedule.eval(t_env)

        if test_mode:
            # Greedy action selection only
            self.epsilon = 0.0

        if self.e_mode == "negative_sample":
            # mask actions that are excluded from selection
            masked_q_values_alone = -agents_inputs_alone.clone()
            masked_q_values_alone[avail_actions == 0.0] = -float(
                "inf")  # should never be selected!
            # mask actions that are excluded from selection
            masked_q_values = agents_inputs.clone()
            masked_q_values[avail_actions == 0.0] = -float(
                "inf")  # should never be selected!
        elif self.e_mode == "exclude_max":
            # mask actions that are excluded from selection
            masked_q_values_alone = agents_inputs_alone.clone()
            masked_q_values_alone[avail_actions == 0.0] = -float(
                "inf")  # should never be selected!
            # mask actions that are excluded from selection
            masked_q_values = agents_inputs.clone()
            masked_q_values[avail_actions == 0.0] = -float(
                "inf")  # should never be selected!

            # Get rid of the top value
            masked_q_values_alone_max = th.argmax(masked_q_values_alone,
                                                  dim=-1,
                                                  keepdim=True)
            masked_q_values_alone_max_oh = th.zeros(
                masked_q_values_alone.shape).cuda()
            masked_q_values_alone_max_oh.scatter_(-1,
                                                  masked_q_values_alone_max, 1)
            masked_q_values_alone[masked_q_values_alone_max_oh == 1] = -1
            masked_q_values_alone[masked_q_values_alone_max_oh == 0] = 0
            masked_q_values_alone = masked_q_values_alone * (th.sum(
                avail_actions, dim=-1, keepdim=True) != 1)
            masked_q_values_alone[masked_q_values_alone == -1] = -float("inf")
            masked_q_values_alone[avail_actions == 0.0] = -float("inf")

        random_numbers = th.rand_like(agents_inputs[:, :, 0])
        pick_random = (random_numbers < self.epsilon).long()
        random_actions = Categorical(avail_actions.float()).sample().long()
        # TODO: these thresholds are hard-coded for now
        if t_env > 1000000:
            random_numbers = th.rand_like(agents_inputs[:, :, 0])
            pick_alone = (random_numbers < self.args.e_prob).long()
            alone_actions = Categorical(
                logits=masked_q_values_alone.float()).sample().long()
            final_random_actions = pick_alone * alone_actions + (
                1 - pick_alone) * random_actions
        else:
            final_random_actions = random_actions

        picked_actions = pick_random * final_random_actions + (
            1 - pick_random) * masked_q_values.max(dim=2)[1]
        return picked_actions
Example #25
    def forward(ctx, scores):
        output = (torch.rand_like(scores) < scores).float()
        ctx.save_for_backward(output)

        return output
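Examples #23 and #25 are the forward halves of a stochastic binary mask. Below is a self-contained version with the usual straight-through backward; the identity gradient is an assumption, since both excerpts omit backward:

import torch

class BernoulliSample(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores):
        # scores in [0, 1] act as keep-probabilities
        return (torch.rand_like(scores) < scores).float()

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the gradient through unchanged
        return grad_output

scores = torch.full((5,), 0.7, requires_grad=True)
mask = BernoulliSample.apply(scores)
mask.sum().backward()
print(scores.grad)  # all ones under the straight-through assumption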
Example #26
def add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.1):
    noise = noise_std * torch.randn_like(voxel)  # mean = 0, std = noise_std
    if noise_fraction < 1.0:
        mask = torch.rand_like(voxel) >= noise_fraction
        noise.masked_fill_(mask, 0)
    return voxel + noise
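Usage sketch: with noise_fraction=0.1, roughly 10% of the entries receive Gaussian noise with standard deviation noise_std and the rest stay untouched (values assumed):

voxel = torch.zeros(2, 8, 8)
noisy = add_noise_to_voxel(voxel, noise_std=0.5, noise_fraction=0.1)
print((noisy != 0).float().mean())  # roughly 0.1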
Example #27
    path = '/media/jojorge/NTFS/yoga/109_2/DeepLearning/DeepLearning_NYCU/lab7_GAN_NF/task_2/CelebA-HQ-img/'
    transform = transforms.Compose([
        transforms.Resize(args.img_size),
        transforms.CenterCrop(args.img_size),
        transforms.ToTensor(),
    ])
    n_bins = 2.0**args.n_bits
    img1 = Image.open(path + '12.jpg')
    img2 = Image.open(path + '13.jpg')
    img1 = preprocess(img1, transform, n_bins)
    img2 = preprocess(img2, transform, n_bins)
    images = [img1, img2]
    z_list = []
    for image in images:
        log_p, logdet, z = model(image + torch.rand_like(image) / n_bins)
        z_list.append(z)

    z1 = z_list[0]
    z2 = z_list[1]
    all_rate = [0.1, 0.3, 0.5, 0.7, 0.9]
    img_concat = img1
    for rate in all_rate:
        interpolate_z = []
        for i in range(len(z1)):
            interpolate_z.append(torch.lerp(z1[i], z2[i], rate))
        with torch.no_grad():
            reconstruct_img = model_single.reverse(interpolate_z).cpu().data
        img_concat = torch.cat((img_concat, reconstruct_img), dim=0)

    img_concat = torch.cat((img_concat, img2), dim=0)
Example #28
def b(seed, k):
    b = torch.empty(100, 10, dtype=k.dtype).normal_(-1, 1)
    b /= b.norm(dim=-1, keepdim=True)
    b *= (torch.rand_like(k) * k) ** 0.5
    return lorentz.math.project(b, k=k)
Example #29
 def test_rand_like(self):
     shape = (5, 1, 1)
     x = torch.rand_like(torch.zeros(shape, device=xm.xla_device()))
     self.assertEqual(x.device.type, 'xla')
Example #30
def PoissonGen(inp, rescale_fac=2.0):
    rand_inp = torch.rand_like(inp).cuda()
    return torch.mul(
        torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(),
        torch.sign(inp))
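PoissonGen is a rate-coding spike generator: each entry fires with probability |inp| / rescale_fac and the spike carries the sign of the input. A sketch (a CUDA device is required because of the hard-coded .cuda() call; values assumed):

inp = torch.full((4,), 0.5).cuda()
spikes = torch.stack([PoissonGen(inp) for _ in range(1000)])
print(spikes.mean(0))  # roughly inp / rescale_fac = 0.25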
Example #31
    def forward(self,
                x,
                scaling_params=None,
                return_samples=False,
                output_noise=True,
                resample_output_noise=True,
                sampling_temperature=0.1,
                seed=None,
                **kwargs):
        D = int(self.output_dims)
        nD = D * self.n_components
        if x.dim() == 1:
            x = x.unsqueeze(0)
        outs = x.split(nD, -1)
        if len(outs) == 4:
            mean, log_std, logit_pi, log_temperature = outs
        else:
            mean, log_std, extras = outs
            logit_pi, log_temperature = extras.split(self.n_components, -1)

        # the output shape is [batch_size, output_dimensions, n_components]
        mean = mean.view(-1, D, self.n_components)
        log_std = log_std.view(-1, D, self.n_components)
        temp = 1e-1 + torch.nn.functional.softplus(log_temperature)
        logit_pi = logit_pi / temp

        # scale and center outputs
        if scaling_params is not None and len(scaling_params) > 0:
            if len(scaling_params) == 2:
                my = scaling_params[0].unsqueeze(-1)
                Sy = scaling_params[1].unsqueeze(-1)
                log_std = log_std + Sy.log()
                mean = mean * Sy + my
            else:
                warnings.warn(
                    "Expected scaling_params as tuple or list with 2 elements")

        if return_samples:
            if seed is not None:
                torch.manual_seed(seed)
            if (logit_pi.shape != self.z_pi.shape) or resample_output_noise:
                u = torch.distributions.utils.clamp_probs(
                    torch.rand_like(logit_pi))
                self.z_pi.data = -(-u.log()).log()
            z1 = self.z_pi
            # sample from gumbel softmax
            k_soft = ((log_softmax(logit_pi, -1) + z1) /
                      sampling_temperature).softmax(-1)
            k_idx = k_soft.argmax(-1).view(-1, 1)
            k_hard = torch.zeros_like(k_soft).scatter(1, k_idx, 1)
            # get hard max (but backprop through softmax)
            k = ((k_hard - k_soft).detach() + k_soft)[:, None, :]
            samples = (mean * k).sum(-1)
            if output_noise:
                # resample when the cached normal noise no longer matches mean's batch shape
                if (mean.shape[:-1] != self.z_normal.shape)\
                        or resample_output_noise:
                    self.z_normal.data = torch.randn(*mean.shape[:-1],
                                                     device=mean.device)
                z2 = self.z_normal
                noise = z2 * (log_std * k).sum(-1).clamp(-15, 15).exp()

                return samples, noise
            return samples
        else:
            return mean, log_std, logit_pi
Example #32
 def random_perturb(self, x):
     return (ch.rand_like(x) - 0.5).renorm(p=2,
                                           dim=1,
                                           maxnorm=self.step_size)
Example #33
def add_r_(data):
    r = torch.rand_like(data)
    data.add_(r)
Example #34
 def __call__(self, loc):
     loc_n = torch.rand_like(loc) * 2 - 1
     if loc_n.is_cuda:
         loc_n = loc_n.cuda()
     return loc_n
Example #35
 def random_perturb(self, x):
     return 2 * (ch.rand_like(x) - 0.5) * self.eps
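Note the contrast with Example #32: there the centered uniform noise is capped to an L2 norm of step_size via renorm, while here it is simply scaled into the L-infinity box [-eps, eps].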
Example #36
    def train(self):
        self.optimizer_G = optim.Adam(self.G.parameters(), self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.optimizer_D = optim.Adam(self.D.parameters(), self.config.d_lr, [self.config.beta1, self.config.beta2])
        self.lr_scheduler_G = optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.lr_decay_iters, gamma=0.1)
        self.lr_scheduler_D = optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.lr_decay_iters, gamma=0.1)

        self.load_checkpoint()
        if self.cuda and self.config.ngpu > 1:
            self.G = nn.DataParallel(self.G, device_ids=list(range(self.config.ngpu)))
            self.D = nn.DataParallel(self.D, device_ids=list(range(self.config.ngpu)))

        val_iter = iter(self.data_loader.val_loader)
        x_sample, c_org_sample = next(val_iter)
        x_sample = x_sample.to(self.device)
        c_sample_list = self.create_labels(c_org_sample, self.config.attrs)
        c_sample_list.insert(0, c_org_sample)  # reconstruction

        self.g_lr = self.lr_scheduler_G.get_lr()[0]
        self.d_lr = self.lr_scheduler_D.get_lr()[0]

        data_iter = iter(self.data_loader.train_loader)
        start_time = time.time()
        for i in range(self.current_iteration, self.config.max_iters):
            self.G.train()
            self.D.train()
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # fetch real images and labels
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader.train_loader)
                x_real, label_org = next(data_iter)

            # generate target domain labels randomly
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            c_org = label_org.clone()
            c_trg = label_trg.clone()

            x_real = x_real.to(self.device)         # input images
            c_org = c_org.to(self.device)           # original domain labels
            c_trg = c_trg.to(self.device)           # target domain labels
            label_org = label_org.to(self.device)   # labels for computing classification loss
            label_trg = label_trg.to(self.device)   # labels for computing classification loss

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # compute loss with real images
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org)

            # compute loss with fake images
            attr_diff = c_trg - c_org
            attr_diff = attr_diff * torch.rand_like(attr_diff) * (2 * self.config.thres_int)
            x_fake = self.G(x_real, attr_diff)
            out_src, out_cls = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # compute loss for gradient penalty
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # backward and optimize
            d_loss_adv = d_loss_real + d_loss_fake + self.config.lambda_gp * d_loss_gp
            d_loss = d_loss_adv + self.config.lambda1 * d_loss_cls
            self.optimizer_D.zero_grad()
            d_loss.backward(retain_graph=True)
            self.optimizer_D.step()

            # summarize
            scalars = {}
            scalars['D/loss'] = d_loss.item()
            scalars['D/loss_adv'] = d_loss_adv.item()
            scalars['D/loss_cls'] = d_loss_cls.item()
            scalars['D/loss_real'] = d_loss_real.item()
            scalars['D/loss_fake'] = d_loss_fake.item()
            scalars['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            if (i + 1) % self.config.n_critic == 0:
                # original-to-target domain
                x_fake = self.G(x_real, attr_diff)
                out_src, out_cls = self.D(x_fake)
                g_loss_adv = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg)

                # target-to-original domain
                x_reconst = self.G(x_fake, c_org - c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # backward and optimize
                g_loss = g_loss_adv + self.config.lambda3 * g_loss_rec + self.config.lambda2 * g_loss_cls
                self.optimizer_G.zero_grad()
                g_loss.backward()
                self.optimizer_G.step()

                # summarize
                scalars['G/loss'] = g_loss.item()
                scalars['G/loss_adv'] = g_loss_adv.item()
                scalars['G/loss_cls'] = g_loss_cls.item()
                scalars['G/loss_rec'] = g_loss_rec.item()

            self.current_iteration += 1

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            if self.current_iteration % self.config.summary_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                print('Elapsed [{}], Iteration [{}/{}]'.format(et, self.current_iteration, self.config.max_iters))
                for tag, value in scalars.items():
                    self.writer.add_scalar(tag, value, self.current_iteration)

            if self.current_iteration % self.config.sample_step == 0:
                self.G.eval()
                with torch.no_grad():
                    x_sample = x_sample.to(self.device)
                    x_fake_list = [x_sample]
                    for c_trg_sample in c_sample_list:
                        attr_diff = c_trg_sample.to(self.device) - c_org_sample.to(self.device)
                        attr_diff = attr_diff * self.config.thres_int
                        x_fake_list.append(self.G(x_sample, attr_diff.to(self.device)))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    self.writer.add_image('sample', make_grid(self.denorm(x_concat.data.cpu()), nrow=1),
                                          self.current_iteration)
                    save_image(self.denorm(x_concat.data.cpu()),
                               os.path.join(self.config.sample_dir, 'sample_{}.jpg'.format(self.current_iteration)),
                               nrow=1, padding=0)

            if self.current_iteration % self.config.checkpoint_step == 0:
                self.save_checkpoint()

            self.lr_scheduler_G.step()
            self.lr_scheduler_D.step()
Example #37
    def train(self, replay_buffer, batch_size=100):
        """Train and update actor and critic networks
            Args:
                replay_buffer (ReplayBuffer): buffer for experience replay
                batch_size(int): batch size to sample from replay buffer
            Return:
                actor_loss (float): loss from actor network
                critic_loss (float): loss from critic network
        """
        self.total_it += 1
        # Sample replay buffer
        state, next_state, action, reward, done = replay_buffer.sample(
            batch_size)

        state = torch.from_numpy(
            np.asarray([np.array(i.item().values()) for i in state]))

        next_state = np.asarray(
            [np.array(i.item().values()) for i in next_state])
        reward = torch.as_tensor(reward, dtype=torch.float32)
        done = torch.as_tensor(done, dtype=torch.float32)

        with torch.no_grad():
            # select an action according to the policy and add clipped noise
            # need to select set of actions
            noise = (torch.rand_like(torch.from_numpy(action)) *
                     self.policy_noise).clamp(-self.noise_clip,
                                              self.noise_clip)

            next_action = (self.actor_target(
                torch.tensor(next_state, dtype=torch.float32)) +
                           torch.tensor(noise, dtype=torch.float32)).clamp(
                               self.max_action[0], self.max_action[2])
            # next_action_d =torch.as_tensor(next_action, dtype=torch.double)
            # Compute the target Q value
            target_Q1, target_Q2 = self.critic(state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + done * self.discount * target_Q

        # update action datatype, can't do earlier, use np.array earlier
        action = torch.as_tensor(action, dtype=torch.float32)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q[:1, :].transpose(
            0, 1)) + F.mse_loss(current_Q2, target_Q[:1, :].transpose(0, 1))

        # optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # compute the actor loss
            actor_loss = -self.critic.get_q(state, self.actor(state)).mean()

            # optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)
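One caveat worth flagging: canonical TD3 draws the target-policy smoothing noise from a zero-mean Gaussian (torch.randn_like), whereas the torch.rand_like above yields values in [0, policy_noise), so after the clamp the noise is non-negative rather than symmetric.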