Beispiel #1
0
def freeze_rng_state():
    """Snapshot the torch RNG state, yield once, then restore the snapshot.

    This is a generator with a single ``yield``; it is presumably wrapped with
    ``contextlib.contextmanager`` by the surrounding module (the decorator is
    not visible here -- confirm).

    The CPU generator state is always captured.  The state of the current CUDA
    device is captured as well when ``torch.cuda.is_available()`` is true.
    Restoration runs in a ``finally`` clause so both states are put back even
    when the code executed at the ``yield`` raises.
    """
    cuda_available = torch.cuda.is_available()
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if cuda_available else None
    try:
        yield
    finally:
        # Restore unconditionally; an exception thrown into the generator must
        # not leave the caller with a modified RNG stream.
        if cuda_available:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
Beispiel #2
0
	def forward(self,seed):
		"""Return every parameter multiplied by a seed-determined Bernoulli mask.

		A 0/1 mask with the shape of each entry of ``self.params`` is sampled
		with probability ``self.p`` after seeding torch with ``seed``, so the
		same seed always yields the same masks.  The global RNG state (CPU, and
		the current CUDA device when CUDA is available) is saved beforehand and
		restored afterwards, even if sampling raises.

		Args:
			seed: value passed to ``torch.manual_seed`` / ``torch.cuda.manual_seed``.

		Returns:
			list with one ``mask * param`` product per entry of ``self.params``.
		"""
		# Only touch the CUDA generator when CUDA exists; querying its state
		# on a CPU-only machine raises.
		has_cuda=torch.cuda.is_available()
		rng_cpu=torch.get_rng_state()
		rng_gpu=torch.cuda.get_rng_state() if has_cuda else None
		try:
			torch.manual_seed(seed)
			torch.cuda.manual_seed(seed)
			mask=[Variable(param.data.clone().bernoulli_(self.p)) for param in self.params]
		finally:
			# Recover rng
			torch.set_rng_state(rng_cpu)
			if has_cuda:
				torch.cuda.set_rng_state(rng_gpu)
		# Compute output
		return [m*param for m,param in zip(mask,self.params)]
Beispiel #3
0
	def noise(self,seed=None):
		"""Sample standard-normal noise shaped like every entry of ``self.mean``.

		Args:
			seed: when ``None`` the noise is drawn from the current global RNG
				stream.  Otherwise torch is seeded with ``seed`` for the draw and
				the previous RNG state (CPU, plus the current CUDA device when
				CUDA is available) is restored afterwards, even if sampling
				raises.

		Returns:
			list with one noise tensor per entry of ``self.mean``.
		"""
		def draw():
			return [Variable(param.data.clone().normal_(0,1)) for param in self.mean]

		if seed is None:
			return draw()

		# Only touch the CUDA generator when CUDA exists; querying its state
		# on a CPU-only machine raises.
		has_cuda=torch.cuda.is_available()
		rng_cpu=torch.get_rng_state()
		rng_gpu=torch.cuda.get_rng_state() if has_cuda else None
		try:
			torch.manual_seed(seed)
			torch.cuda.manual_seed(seed)
			# generate noise
			return draw()
		finally:
			# Recover rng
			torch.set_rng_state(rng_cpu)
			if has_cuda:
				torch.cuda.set_rng_state(rng_gpu)
Beispiel #4
0
    def __getitem__(self, index):
        """
        Build the random sample associated with ``index``.

        The image and target are drawn after seeding torch with
        ``index + self.random_offset``, so a given index always produces the
        same sample.  The caller's RNG state is saved first and restored in a
        ``finally`` clause, so it is put back even if sampling fails.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        # create random image that is consistent with the index id
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(index + self.random_offset)
            img = torch.randn(*self.image_size)
            target = torch.Tensor(1).random_(0, self.num_classes)[0]
        finally:
            torch.set_rng_state(rng_state)

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
Beispiel #5
0
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
    """
    Snapshot the RNG state on entry and put it back on exit, so random draws
    made in between do not affect the RNG stream seen afterwards.

    Arguments:
        devices (iterable of CUDA IDs): CUDA devices whose RNG state is
            snapshotted.  The CPU RNG state is always snapshotted.  When left
            as ``None`` every device reported by ``torch.cuda.device_count()``
            is used, and a one-time warning is emitted if there is more than
            one device.  Passing devices explicitly suppresses the warning.
        enabled (bool): if ``False``, nothing is snapshotted or restored; the
            function just yields.  Handy for switching the behaviour off
            without reindenting the calling code.

    Internal arguments:
        _caller: name of the function that called ``fork_rng`` on the user's
            behalf; used only in the warning text.
        _devices_kw: name of the devices keyword of ``_caller``; used only in
            the warning text.
    """

    import torch.cuda
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is not None:
        # The devices are walked twice (save, then restore); materialize them
        # so that a one-shot iterator is not exhausted by the first pass.
        devices = list(devices)
    else:
        num_devices = torch.cuda.device_count()
        should_warn = num_devices > 1 and not _fork_rng_warned_already
        if should_warn:
            message = (
                "CUDA reports that you have {num_devices} available devices, and you "
                "have used {caller} without explicitly specifying which devices are being used. "
                "For safety, we initialize *every* CUDA device by default, which "
                "can be quite slow if you have a lot of GPUs.  If you know that you are only "
                "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                "you are actually using.  For example, if you are using CPU only, "
                "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0].  To initialize "
                "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                "to `range(torch.cuda.device_count())`."
            ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw)
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))

    saved_cpu_state = torch.get_rng_state()
    saved_gpu_states = []
    for dev in devices:
        with torch.cuda.device(dev):
            saved_gpu_states.append(torch.cuda.get_rng_state())

    try:
        yield
    finally:
        torch.set_rng_state(saved_cpu_state)
        for position, dev in enumerate(devices):
            with torch.cuda.device(dev):
                torch.cuda.set_rng_state(saved_gpu_states[position])
def validate(epoch, step, iterator_va, num_va):
    """Evaluate the networks on the validation iterator and return mean losses.

    The function operates on module-level objects that are not passed in:
    ``netG``, ``netD``, ``netE``, ``device``, ``mean``, ``std``, ``z_dim``,
    ``z_total_scale_factor``, ``k``, ``recorder``, ``loss_funcs`` and
    ``wandb``.

    The RNG is seeded with a fixed value (123) for the duration of the pass.
    The caller's RNG state (CPU, plus the current CUDA device when ``device``
    is a CUDA device) is saved first and restored in a ``finally`` clause, so
    it is put back even if validation raises.  The three networks are switched
    to eval mode and are left in eval mode on return.

    Args:
        epoch: value logged under ``"epoch"`` with every batch.
        step: forwarded to ``wandb.log`` as ``step``.
        iterator_va: sized iterable of validation batches; each batch is
            unpacked as ``(bs, feat_dim, num_frames)``.
        num_va: divisor that turns the batch-size-weighted loss sums into
            means (presumably the number of validation samples -- confirm
            with the caller).

    Returns:
        OrderedDict mapping each name in ``loss_funcs`` to its summed,
        batch-size-weighted loss divided by ``num_va``.
    """
    # Store random state
    cpu_rng_state_tr = torch.get_rng_state()
    if device.type == "cuda":
        gpu_rng_state_tr = torch.cuda.get_rng_state()

    try:
        # Set random state
        torch.manual_seed(123)

        # ###
        sum_losses_va = OrderedDict([(loss_name, 0) for loss_name in loss_funcs])

        # The averaging divisor is the caller-supplied count, not a running
        # total of the batch sizes seen in the loop.
        count_all_va = num_va

        # In validation, set netG.eval()
        netG.eval()
        netD.eval()
        netE.eval()
        num_batches_va = len(iterator_va)
        with torch.set_grad_enabled(False):
            for i_batch, batch in enumerate(iterator_va):

                # voc.shape=(bs, feat_dim, num_frames)
                voc = batch
                voc = voc.to(device)
                voc = (voc - mean) / std

                bs, _, nf = voc.size()

                # ### Generator losses ###
                z = (torch.zeros(
                    (bs, z_dim, int(np.ceil(nf / z_total_scale_factor)))).normal_(
                        0, 1).float().to(device))

                fake_voc = netG(z)

                z_fake = netE(fake_voc)
                z_real = netE(voc)

                gloss = torch.mean(torch.abs(netD(fake_voc) - fake_voc))
                noise_rloss = torch.mean(torch.abs(z_fake - z))
                real_rloss = torch.mean(
                    torch.abs(netG(z_real)[..., :nf] - voc[..., :nf]))

                # ### Discriminator losses ###
                real_dloss = torch.mean(torch.abs(netD(voc) - voc))
                fake_dloss = torch.mean(
                    torch.abs(netD(fake_voc.detach()) - fake_voc.detach()))

                dloss = real_dloss - k * fake_dloss

                # ### Convergence ###
                # update_k=False: validation must not change the balancing term.
                _, convergence = recorder(real_dloss, fake_dloss, update_k=False)

                # ### Losses ###
                losses = OrderedDict([
                    ("G", gloss),
                    ("D", dloss),
                    ("RealD", real_dloss),
                    ("FakeD", fake_dloss),
                    ("Convergence", convergence),
                    ("NoiseRecon", noise_rloss),
                    ("RealRecon", real_rloss),
                ])

                # Accumulate losses
                losses_va = OrderedDict([(loss_name, lo.item())
                                         for loss_name, lo in losses.items()])

                for loss_name, lo in losses_va.items():
                    sum_losses_va[loss_name] += lo * bs

                if i_batch % 10 == 0:
                    print("{}/{}".format(i_batch, num_batches_va))

                wandb.log({"loss/eval": losses, "epoch": epoch}, step=step)

        mean_losses_va = OrderedDict([(loss_name, slo / count_all_va)
                                      for loss_name, slo in sum_losses_va.items()])
    finally:
        # Restore rng state
        torch.set_rng_state(cpu_rng_state_tr)
        if device.type == "cuda":
            torch.cuda.set_rng_state(gpu_rng_state_tr)

    return mean_losses_va
Beispiel #7
0
def set_rng_state(state):
    """Restore RNG states from the mapping ``state``.

    ``state["torch_rng_state"]`` is always applied to the torch CPU generator.
    ``state["xla_rng_state"]`` is read and applied through ``xm`` only when the
    module-level ``xm`` is not ``None``.  ``state["cuda_rng_state"]`` is read
    and applied only when CUDA is available.
    """
    torch.set_rng_state(state["torch_rng_state"])

    xla_loaded = xm is not None
    if xla_loaded:
        xm.set_rng_state(state["xla_rng_state"])

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.set_rng_state(state["cuda_rng_state"])
def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Apply a saved RNG snapshot to :mod:`torch`, :mod:`numpy` and Python in the current process.

    The snapshot is read under the keys ``"torch"``, ``"numpy"`` and
    ``"python"``; the Python entry is a ``(version, state, gauss)`` triple.
    """
    torch_state = rng_state_dict["torch"]
    torch.set_rng_state(torch_state)

    numpy_state = rng_state_dict["numpy"]
    np.random.set_state(numpy_state)

    python_state = rng_state_dict["python"]
    version, internal_state, gauss_next = python_state
    # The internal state is forced into a tuple before it is handed over.
    python_set_rng_state((version, tuple(internal_state), gauss_next))
Beispiel #9
0
 def __enter__(self):
     """Fork the RNG for the forward GPU devices, then load the forward-pass states.

     The ``fork_rng`` context is entered by hand and kept on ``self._fork``;
     leaving it is not done here.  Afterwards the CPU generator is set to
     ``self.fwd_cpu_state`` and the devices in ``self.fwd_gpu_devices`` are
     given ``self.fwd_gpu_states``.  Nothing is returned.
     """
     gpu_devices = self.fwd_gpu_devices
     self._fork = torch.random.fork_rng(devices=gpu_devices, enabled=True)
     self._fork.__enter__()

     torch.set_rng_state(self.fwd_cpu_state)
     set_device_states(gpu_devices, self.fwd_gpu_states)
Beispiel #10
0
 def __exit__(self, type, value, traceback):
     torch.set_rng_state(self.old_rng_state)
Beispiel #11
0
def set_rng_state(state):
    """Restore the torch RNG from ``state``.

    ``state["torch_rng_state"]`` is always applied to the CPU generator;
    ``state["cuda_rng_state"]`` is read and applied only when CUDA is
    available.
    """
    cpu_state = state["torch_rng_state"]
    torch.set_rng_state(cpu_state)

    if not torch.cuda.is_available():
        return
    torch.cuda.set_rng_state(state["cuda_rng_state"])
def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process.

    The snapshot is read under the keys ``"torch"``, ``"numpy"`` and
    ``"python"``; the Python entry is a ``(version, state, gauss)`` triple.

    The Python internal state is converted to a tuple before it is applied,
    matching the sibling ``_set_rng_states``.  This lets a snapshot whose
    internal state arrives as a list (for example after a round trip through
    a serializer -- presumably the reason the sibling converts it) be
    restored instead of being rejected.
    """
    torch.set_rng_state(rng_state_dict.get("torch"))
    np.random.set_state(rng_state_dict.get("numpy"))
    version, state, gauss = rng_state_dict.get("python")
    python_set_rng_state((version, tuple(state), gauss))
Beispiel #13
0
    def backward(ctx, *grad_outputs):  # type: ignore
        """Recompute each model shard and backpropagate through it, last shard first.

        For every shard (walked in reverse together with the stored
        activations, excluding the last one) the shard is loaded, its input
        activation and the latest gradient are chunked into
        ``model_instance._num_microbatches`` pieces, the shard is re-run on
        each chunk under ``ctx.fwd_rng_state`` and the chunk's gradient is
        backpropagated.  The gradients collected on the chunked activations
        are concatenated and become the incoming gradient for the next shard.

        Args:
            ctx: autograd context; ``inputs``, ``model_instance``,
                ``grad_requirements`` and ``fwd_rng_state`` are read from it.
            *grad_outputs: gradients of the outputs; used as the first entry
                of the gradient chain.

        Returns:
            ``(None, None)`` followed by one entry per element of
            ``model_instance._activations[0]``: its ``.grad`` for tensors, the
            element itself otherwise.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is false.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )
        inputs = ctx.inputs
        model_instance = ctx.model_instance

        # Re-apply the requires_grad flags recorded in ctx.grad_requirements.
        for i, need_grad in enumerate(ctx.grad_requirements):
            inputs[i].requires_grad = need_grad

        # Gradient chain; the last element is always the gradient flowing into
        # the shard currently being processed.
        all_grads = [grad_outputs]

        for model_shard, activation in zip(
                reversed(model_instance.model_slices),
                reversed(model_instance._activations[:-1])):
            # Move the model shard to the device.
            model_shard.backward_load()
            # Store the BW pass state.
            bwd_rng_state = torch.get_rng_state()

            # TODO(anj-s): Why detach inputs?
            activation = torch.utils.checkpoint.detach_variable(activation)
            # Get the last gradient calculation.
            final_grads = all_grads[-1]

            # Normalize both to tuples so they can be star-unpacked below.
            if isinstance(activation, torch.Tensor):
                activation = (activation, )
            if isinstance(final_grads, torch.Tensor):
                final_grads = (final_grads, )
            # Iterate through all the inputs/outputs of a shard (there could be multiple).
            chunked_grad_list: List[Any] = []
            # Chunk the activation and grad based on the number of microbatches that are set.
            for chunked_activation, chunked_grad in zip(
                    torch.chunk(
                        *activation,
                        model_instance._num_microbatches),  # type: ignore
                    torch.chunk(
                        *final_grads,
                        model_instance._num_microbatches),  # type: ignore
            ):
                # Set the states to what it used to be before the forward pass.
                torch.set_rng_state(ctx.fwd_rng_state)

                if isinstance(chunked_activation, torch.Tensor):
                    chunked_activation = (chunked_activation, )  # type: ignore
                if isinstance(chunked_grad, torch.Tensor):
                    chunked_grad = (chunked_grad, )  # type: ignore

                # Since we need a grad value of a non leaf element we need to set these properties.
                for a in chunked_activation:
                    # torch.long activations are skipped: they are not marked
                    # as requiring grad.
                    if a.dtype == torch.long:
                        continue
                    a.requires_grad = True
                    a.retain_grad()

                with torch.enable_grad():
                    # calculate the output of the last shard wrt to the stored activation at the slice boundary.
                    outputs = model_shard(*chunked_activation)

                # Set the states back to what it was at the start of this function.
                torch.set_rng_state(bwd_rng_state)
                torch.autograd.backward(outputs, chunked_grad)
                # Collect the gradients that landed on this chunk's activations.
                intermediate_grads = []
                for a in chunked_activation:
                    if a.grad is not None:
                        intermediate_grads.append(a.grad)
                if None not in intermediate_grads:
                    chunked_grad_list += intermediate_grads
            if chunked_grad_list:
                # Append the list of grads to the all_grads list and this should be on the CPU.
                all_grads.append(
                    torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
            # TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
            # Move the shard back to the CPU.
            model_shard.backward_drop()
        detached_inputs = model_instance._activations[0]
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        # The two leading None entries are returned ahead of the per-input
        # gradients (presumably for two non-tensor forward arguments -- confirm
        # against the matching forward()).
        return (None, None) + grads
Beispiel #14
0
def train(classifier: torch.nn.Module,
          x: torch.Tensor,
          y: torch.Tensor,
          test_x: torch.Tensor = torch.Tensor(),
          test_y: torch.Tensor = torch.Tensor(),
          batch_size: int = 16,
          num_epochs: int = 2,
          run_device: str = "cpu",
          learning_rate: float = 0.001,
          beta_1: float = 0.9,
          beta_2: float = 0.999,
          random_state: torch.ByteTensor = torch.get_rng_state().clone(),
          verbose: bool = False) -> Tuple[torch.nn.Module, torch.ByteTensor]:
    """
    Train a deep copy of a classifier with Adam and cross-entropy loss.

    The global torch RNG is set to ``random_state`` for the duration of the
    call and the caller's RNG state is restored before returning, also when
    training raises.  Early stopping (patience 10) is driven by the mean test
    loss when test data is given, otherwise by the mean training loss; both
    means run over all batches seen so far, not just the current epoch.

    Notes
    -----
    * ``classifier`` is moved to the CPU in place (``classifier.cpu()``) before
      it is deep-copied; the copy is what gets trained.
    * The default for ``random_state`` is evaluated once, when the function is
      defined, not on every call.

    Parameters
    ----------
    classifier: torch.nn.Module
        Model to copy and train.
    x: torch.Tensor
        Training inputs.
    y: torch.Tensor
        Training targets.
    test_x: torch.Tensor
        Optional test inputs; test data is used only when both ``test_x`` and
        ``test_y`` have a non-zero first dimension.
    test_y: torch.Tensor
        Optional test targets.
    batch_size: int
        Batch size for both data loaders; must be positive.
    num_epochs: int
        Maximum number of epochs; must be positive.
    run_device: str
        ``"cpu"`` or ``"cuda"`` (case-insensitive).
    learning_rate: float
        Adam learning rate; must be positive.
    beta_1: float
        First Adam beta, in ``[0, 1)``.
    beta_2: float
        Second Adam beta, in ``[0, 1)``.
    random_state: torch.ByteTensor
        RNG state to train under.
    verbose: bool
        Print one progress line per epoch.

    Returns
    -------
    Tuple[torch.nn.Module, torch.Tensor]
        The trained copy (on the CPU) and the RNG state reached at the end of
        training.

    """
    assert isinstance(classifier, torch.nn.Module)
    assert isinstance(x, torch.Tensor)
    assert isinstance(y, torch.Tensor)
    assert isinstance(test_x, torch.Tensor)
    assert isinstance(test_y, torch.Tensor)
    assert isinstance(batch_size, int) and (batch_size > 0)
    assert isinstance(num_epochs, int) and (num_epochs > 0)
    assert isinstance(run_device, str) and (run_device.lower()
                                            in ["cpu", "cuda"])
    assert isinstance(learning_rate, float) and (learning_rate > 0.0)
    assert isinstance(beta_1, float) and (0.0 <= beta_1 < 1.0)
    assert isinstance(beta_2, float) and (0.0 <= beta_2 < 1.0)
    assert isinstance(random_state, torch.ByteTensor)
    assert isinstance(verbose, bool)

    # Set the seed for generating random numbers.
    random_state_previous: torch.ByteTensor = torch.get_rng_state().clone()
    torch.set_rng_state(random_state)

    try:
        # Set the classifier.
        classifier_train: torch.nn.Module = copy.deepcopy(classifier.cpu())
        run_device_train: str = run_device.lower()
        if run_device_train == "cuda":
            assert torch.cuda.is_available()
            classifier_train = classifier_train.cuda()
            if torch.cuda.device_count() > 1:
                num_gpus: int = torch.cuda.device_count()
                classifier_train = torch.nn.DataParallel(
                    classifier_train, device_ids=list(range(0, num_gpus)))

        # Set a criterion and optimizer.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(params=classifier_train.parameters(),
                                     lr=learning_rate,
                                     betas=(beta_1, beta_2))

        # Convert PyTorch's Tensor to TensorDataset.
        x_train, y_train = x.clone(), y.clone()
        dataset_train = torch.utils.data.TensorDataset(x_train, y_train)
        dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                       batch_size=batch_size,
                                                       num_workers=0,
                                                       shuffle=True)

        has_test: bool = False
        if (test_x.size(0) > 0) and (test_y.size(0) > 0):
            x_test, y_test = test_x.clone(), test_y.clone()
            dataset_test = torch.utils.data.TensorDataset(x_test, y_test)
            dataloader_test = torch.utils.data.DataLoader(
                dataset_test,
                batch_size=batch_size,
                num_workers=0,
                shuffle=False)
            has_test = True

        # Initialize the early_stopping object.
        early_stopping: _EarlyStopping = _EarlyStopping(patience=10,
                                                        delta=0.0,
                                                        verbose=False)

        log_template: str = "[{0}/{1}] Loss: {2:.4f}, Time: {3:.2f}s"
        log_template_test: str = "[{0}/{1}] Loss (Train): {2:.4f}, Loss (Test): {3:.4f}, Time: {4:.2f}s"

        # Both lists grow across epochs; the reported means are cumulative.
        list_loss: list = list()
        list_loss_test: list = list()

        # Train the classifiers.
        classifier_train.train()
        for epoch in range(1, num_epochs + 1):
            start_time: float = time.time()
            for batch_x, batch_y in dataloader_train:
                if run_device_train == "cuda":
                    batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                optimizer.zero_grad()
                output: torch.Tensor = classifier_train(batch_x)
                loss: torch.Tensor = criterion(output, batch_y)
                loss.backward()
                optimizer.step()

                list_loss.append(loss.detach().cpu().item())
            end_time: float = time.time()

            if has_test:
                classifier_train.eval()
                # Only the loss values are needed here, so skip building
                # autograd graphs for the test batches.
                with torch.no_grad():
                    for batch_x, batch_y in dataloader_test:
                        if run_device_train == "cuda":
                            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                        output = classifier_train(batch_x)
                        loss = criterion(output, batch_y)

                        list_loss_test.append(loss.detach().cpu().item())
                classifier_train.train()

                early_stopping(loss=np.mean(list_loss_test),
                               model=classifier_train)
            else:
                early_stopping(loss=np.mean(list_loss), model=classifier_train)

            if verbose:
                if has_test:
                    print(
                        log_template_test.format(epoch, num_epochs,
                                                 np.mean(list_loss),
                                                 np.mean(list_loss_test),
                                                 end_time - start_time))
                else:
                    print(
                        log_template.format(epoch, num_epochs,
                                            np.mean(list_loss),
                                            end_time - start_time))

            if early_stopping.early_stop:
                # Roll back to the best model and the RNG state stored with it.
                state_dict, rng_state = early_stopping.get_best_model()
                classifier_train.load_state_dict(state_dict)
                torch.set_rng_state(rng_state)
                break

        if isinstance(classifier_train, torch.nn.DataParallel):
            classifier_train = classifier_train.module

        random_state_after: torch.ByteTensor = torch.get_rng_state().clone()
    finally:
        # Always hand the caller's RNG stream back, even if training failed.
        torch.set_rng_state(random_state_previous)

    return classifier_train.cpu(), random_state_after.clone()
Beispiel #15
0
def predict(
    classifier: torch.nn.Module,
    x: torch.Tensor,
    run_device: str = "cpu",
    random_state: torch.ByteTensor = torch.get_rng_state().clone()
) -> np.ndarray:
    """
    Predict class indices for ``x`` with a deep copy of a trained classifier.

    The global torch RNG is set to ``random_state`` for the duration of the
    call and the caller's RNG state is restored before returning, also when
    prediction raises.  All of ``x`` is evaluated as a single batch in eval
    mode without gradient tracking.

    Notes
    -----
    * ``classifier`` is moved to the CPU in place (``classifier.cpu()``) before
      it is deep-copied.
    * The default for ``random_state`` is evaluated once, when the function is
      defined, not on every call.

    Parameters
    ----------
    classifier: torch.nn.Module
        Trained model.
    x: torch.Tensor
        Inputs to classify; the first dimension indexes samples.
    run_device: str
        ``"cpu"`` or ``"cuda"`` (case-insensitive).
    random_state: torch.ByteTensor
        RNG state to predict under.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(x), 1)`` holding the argmax over dimension 1 of
        the classifier output.

    """
    assert isinstance(classifier, torch.nn.Module)
    assert isinstance(x, torch.Tensor)
    assert isinstance(run_device, str) and (run_device.lower()
                                            in ["cpu", "cuda"])
    assert isinstance(random_state, torch.ByteTensor)

    # Set the seed for generating random numbers.
    random_state_previous: torch.ByteTensor = torch.get_rng_state().clone()
    torch.set_rng_state(random_state)

    try:
        # Set the classifiers.
        classifier_predict: torch.nn.Module = copy.deepcopy(classifier.cpu())
        run_device_predict: str = run_device.lower()
        if run_device_predict == "cuda":
            assert torch.cuda.is_available()
            classifier_predict = classifier_predict.cuda()
            if torch.cuda.device_count() > 1:
                num_gpus: int = torch.cuda.device_count()
                classifier_predict = torch.nn.DataParallel(
                    classifier_predict, device_ids=list(range(0, num_gpus)))

        # Convert PyTorch's Tensor to TensorDataset.
        x_predict: torch.Tensor = x.clone()
        dataset = torch.utils.data.TensorDataset(x_predict)
        # One batch that spans the whole dataset.
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=x_predict.size(0),
                                                 num_workers=0,
                                                 shuffle=False)

        batch_x: torch.Tensor = next(iter(dataloader))[0]
        if run_device_predict == "cuda":
            batch_x = batch_x.cuda()

        classifier_predict.eval()
        with torch.no_grad():
            output: torch.Tensor = classifier_predict(batch_x)
            predict_y: torch.Tensor = output.detach().cpu().argmax(
                dim=1, keepdim=True)
    finally:
        # Always hand the caller's RNG stream back, even if prediction failed.
        torch.set_rng_state(random_state_previous)

    return predict_y.numpy()
Beispiel #16
0
def rise(model,
         input,
         target=None,
         seed=0,
         num_masks=8000,
         num_cells=7,
         filter_masks=None,
         batch_size=32,
         p=0.5,
         resize=False,
         resize_mode='bilinear'):
    r"""RISE.

    The random number generator is seeded with :attr:`seed` while the masks
    are generated.  The previous generator state (CPU, plus the device of
    :attr:`input` when that is a CUDA device) is saved beforehand and restored
    afterwards, also when mask generation or the model raises.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor with two spatial
            dimensions (4D).
        target (optional): accepted for interface compatibility; not used by
            this function. Default: ``None``.
        seed (int, optional): manual seed used to generate random numbers.
            Default: ``0``.
        num_masks (int, optional): number of RISE random masks to use.
            Default: ``8000``.
        num_cells (int, optional): number of cells for one spatial dimension
            in low-res RISE random mask. Default: ``7``.
        filter_masks (:class:`torch.Tensor`, optional): If given, use the
            provided pre-computed filter masks; its length must equal
            :attr:`num_masks`. Default: ``None``.
        batch_size (int, optional): batch size to use. Default: ``32``.
        p (float, optional): probability used when thresholding the uniform
            samples that form the low-res binary grid (``rand < p``).
            Default: ``0.5``.
        resize (bool or tuple of ints, optional): If True, resize saliency map
            to size of :attr:`input`. If False, don't resize. If (width,
            height) tuple, resize to (width, height). Default: ``False``.
        resize_mode (str, optional): If resize is not None, use this mode for
            the resize function. Default: ``'bilinear'``.

    Returns:
        :class:`torch.Tensor`: RISE saliency map.
    """
    with torch.no_grad():
        # Get device of input (i.e., GPU).
        dev = input.device

        # Initialize saliency mask and mask normalization term.
        input_shape = input.shape
        saliency_shape = list(input_shape)

        height = input_shape[2]
        width = input_shape[3]

        out = model(input)
        num_classes = out.shape[1]

        saliency_shape[1] = num_classes
        saliency = torch.zeros(saliency_shape, device=dev)

        # Number of spatial dimensions.
        nsd = len(input.shape) - 2
        assert nsd == 2

        # Spatial size of low-res grid cell.
        cell_size = tuple(
            [int(np.ceil(s / num_cells)) for s in input_shape[2:]])

        # Spatial size of upsampled mask with buffer (input size + cell size).
        up_size = tuple(
            [input_shape[2 + i] + cell_size[i] for i in range(nsd)])

        # Save current random number generator state.  torch.manual_seed also
        # reseeds the CUDA generators, so keep the state of the input's CUDA
        # device too when the masks are generated there.
        state = torch.get_rng_state()
        cuda_state = None
        if dev.type == 'cuda':
            cuda_state = torch.cuda.get_rng_state(dev)

        try:
            # Set seed.
            torch.manual_seed(seed)

            if filter_masks is not None:
                assert len(filter_masks) == num_masks

            num_chunks = (num_masks + batch_size - 1) // batch_size
            for chunk in range(num_chunks):
                # Generate RISE random masks on the fly.
                mask_bs = min(num_masks - batch_size * chunk, batch_size)

                if filter_masks is None:
                    # Generate low-res, random binary masks.
                    grid = (torch.rand(
                        mask_bs, 1, *((num_cells, ) * nsd), device=dev) <
                            p).float()

                    # Upsample low-res masks to input shape + buffer.
                    masks_up = _upsample_reflect(grid, up_size)

                    # Save final RISE masks with random shift.
                    masks = torch.empty(mask_bs,
                                        1,
                                        *input_shape[2:],
                                        device=dev)
                    shift_x = torch.randint(0,
                                            cell_size[0], (mask_bs, ),
                                            device='cpu')
                    shift_y = torch.randint(0,
                                            cell_size[1], (mask_bs, ),
                                            device='cpu')
                    for i in range(mask_bs):
                        masks[i] = masks_up[i, :,
                                            shift_x[i]:shift_x[i] + height,
                                            shift_y[i]:shift_y[i] + width]
                else:
                    masks = filter_masks[chunk * batch_size:chunk *
                                         batch_size + mask_bs]

                # Accumulate saliency mask.
                for i, inp in enumerate(input):
                    out = torch.sigmoid(model(inp.unsqueeze(0) * masks))
                    if len(out.shape) == 4:
                        # TODO: Consider handling FC outputs more flexibly.
                        assert out.shape[2] == 1
                        assert out.shape[3] == 1
                        out = out[:, :, 0, 0]
                    sal = torch.matmul(out.data.transpose(0, 1),
                                       masks.view(mask_bs, height * width))
                    sal = sal.view((num_classes, height, width))
                    saliency[i] = saliency[i] + sal
        finally:
            # Restore original random number generator state.
            torch.set_rng_state(state)
            if cuda_state is not None:
                torch.cuda.set_rng_state(cuda_state, dev)

        # Normalize saliency mask.
        saliency /= num_masks

        # Resize saliency mask if needed.
        saliency = resize_saliency(input, saliency, resize, mode=resize_mode)
        return saliency
Beispiel #17
0
 def setUp(self) -> None:
     """Put the global torch RNG into the state produced by seeding with 42."""
     seeded_generator = torch.manual_seed(42)
     seeded_state = seeded_generator.get_state()
     torch.set_rng_state(seeded_state)
Beispiel #18
0
def calculate_f1_score(individual: np.ndarray,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       list_sample_by_label: list,
                       random_state: torch.ByteTensor = torch.get_rng_state().clone(),
                       **kwargs) -> Tuple:
    """
    Fitness function: train a classifier on augmented data and score it with F1.

    The training set is ``x``/``y`` extended, per label and per method, with
    the first ``int(count(label) * ratio)`` samples of
    ``list_sample_by_label[method][label]``, where the ratios come from
    ``individual[label_position][method]``.  A ``DNNClassifier`` is trained on
    it and evaluated on the original ``x``/``y``.

    The global torch RNG is set to ``random_state`` and is not restored by
    this function.  The default for ``random_state`` is evaluated once, when
    the function is defined.

    Parameters
    ----------
    individual: np.ndarray
        Ratios indexed as ``[label_position][method]``.
    x: torch.Tensor
        Original inputs.
    y: torch.Tensor
        Original labels.
    list_sample_by_label: list
        One mapping per method from label to an ``(inputs, labels)`` pair.
    random_state: torch.ByteTensor
        RNG state to start from.
    **kwargs
        Optional overrides, each validated with an assertion:
        ``classifier_num_hidden_layers`` (int > 0, default 1),
        ``classifier_batch_size`` (int > 0, default 16),
        ``classifier_num_epochs`` (int > 0, default 2),
        ``classifier_run_device`` ("cpu" or "cuda", default "cpu"),
        ``classifier_learning_rate`` (float > 0, default 0.001),
        ``classifier_beta_1`` (float in [0, 1), default 0.9),
        ``classifier_beta_2`` (float in [0, 1), default 0.999).

    Returns
    -------
    Tuple
        One-element tuple holding the mean of the F1 scores.

    """
    assert isinstance(individual, np.ndarray)
    assert isinstance(x, torch.Tensor)
    assert isinstance(y, torch.Tensor)
    assert isinstance(list_sample_by_label, list)
    assert isinstance(random_state, torch.ByteTensor)

    # Network shape.
    classifier_num_hidden_layers: int = 1
    if "classifier_num_hidden_layers" in kwargs:
        given = kwargs["classifier_num_hidden_layers"]
        assert isinstance(given, int) and (given > 0)
        classifier_num_hidden_layers = given

    # Parameters for training.
    classifier_batch_size: int = 16
    if "classifier_batch_size" in kwargs:
        given = kwargs["classifier_batch_size"]
        assert isinstance(given, int) and (given > 0)
        classifier_batch_size = given
    classifier_num_epochs: int = 2
    if "classifier_num_epochs" in kwargs:
        given = kwargs["classifier_num_epochs"]
        assert isinstance(given, int) and (given > 0)
        classifier_num_epochs = given

    # Check the running device for PyTorch.
    classifier_run_device: str = "cpu"
    if "classifier_run_device" in kwargs:
        given = kwargs["classifier_run_device"]
        assert isinstance(given, str)
        lowered = str(given).lower()
        assert lowered in ["cpu", "cuda"]
        classifier_run_device = lowered

    # Parameters for Adam optimizer.
    classifier_learning_rate: float = 0.001
    if "classifier_learning_rate" in kwargs:
        given = kwargs["classifier_learning_rate"]
        assert isinstance(given, float) and (given > 0.0)
        classifier_learning_rate = given
    classifier_beta_1: float = 0.9
    if "classifier_beta_1" in kwargs:
        given = kwargs["classifier_beta_1"]
        assert isinstance(given, float) and (0.0 <= given < 1.0)
        classifier_beta_1 = given
    classifier_beta_2: float = 0.999
    if "classifier_beta_2" in kwargs:
        given = kwargs["classifier_beta_2"]
        assert isinstance(given, float) and (0.0 <= given < 1.0)
        classifier_beta_2 = given

    # Set the seed for generating random numbers.
    torch.set_rng_state(random_state)

    size_features: int = x.size(1)
    label_span = y.max().item() - y.min().item()
    size_labels: int = int(label_span) + 1

    y_stats: dict = Counter(y.numpy())
    list_label: list = list(list_sample_by_label[0].keys())

    # Start from the original data and append the selected extra samples.
    train_x: torch.Tensor = x.clone()
    train_y: torch.Tensor = y.clone()
    for (position, ratio_by_label) in enumerate(individual):
        label: int = list_label[position]
        for (method, ratio_by_method) in enumerate(ratio_by_label):
            if not ratio_by_method > 0.0:
                continue
            number: int = int(y_stats[label] * ratio_by_method)
            extra_inputs, extra_labels = (list_sample_by_label[method][label][0],
                                          list_sample_by_label[method][label][1])
            new_x: torch.Tensor = torch.as_tensor(extra_inputs[:number],
                                                  dtype=torch.float)
            new_y: torch.Tensor = torch.as_tensor(extra_labels[:number],
                                                  dtype=torch.long)

            train_x = torch.cat([train_x, new_x], dim=0)
            train_y = torch.cat([train_y, new_y], dim=0)

    # The untouched original data doubles as the test set during training.
    test_x: torch.Tensor = x.clone()
    test_y: torch.Tensor = y.clone()

    classifier: torch.nn.Module = DNNClassifier(
        size_features=size_features,
        num_hidden_layers=classifier_num_hidden_layers,
        size_labels=size_labels)
    trained_classifier, trained_random_state = classifier_train(
        classifier=classifier,
        x=train_x,
        y=train_y,
        test_x=test_x,
        test_y=test_y,
        batch_size=classifier_batch_size,
        num_epochs=classifier_num_epochs,
        run_device=classifier_run_device,
        learning_rate=classifier_learning_rate,
        beta_1=classifier_beta_1,
        beta_2=classifier_beta_2,
        random_state=torch.get_rng_state().clone(),
        verbose=False)
    f1_score: np.ndarray = classifier_evaluate(
        classifier=trained_classifier,
        x=x,
        y=y,
        metric="f1_score",
        run_device=classifier_run_device,
        random_state=trained_random_state,
        verbose=False)

    return (np.mean(f1_score), )
# Resume architecture 2 from its saved weights when the file exists.
if os.path.exists(modelToUse_2):
    model_2 = discriminatorNet_archi_2()
    model_2.load_state_dict(torch.load(modelToUse_2))
    model_2 = model_2.to(device2)
    print('Loaded previous model for archi 2')
# Resume architecture 3 from its saved weights when the file exists.
if os.path.exists(modelToUse_3):
    model_3 = discriminatorNet_archi_3()
    model_3.load_state_dict(torch.load(modelToUse_3))
    model_3 = model_3.to(device3)
    print('Loaded previous model for archi 3')

# set the torch random state to what it last was
# NOTE(review): this path is built as out_path_1 + 'randomState.txt' (no
# separator), while the loss files below use out_path_1 + '/...'.  Confirm
# that this matches the location the state file was written to.
if os.path.exists(f'{out_path_1}randomState.txt'):
    # The state is stored as text; convert it back to the uint8 tensor that
    # torch.set_rng_state expects.
    random_state = torch.from_numpy(
        np.loadtxt(f'{out_path_1}randomState.txt').astype('uint8'))
    torch.set_rng_state(random_state)
    print('Loaded torch random state')

# load the previous training losses
# Both history files must exist for either of them to be loaded.
if os.path.exists(out_path_1 +
                  '/allValLoss.txt') and os.path.exists(out_path_1 +
                                                        '/allTrainLoss.txt'):
    allValLoss_tmp_1 = np.loadtxt(out_path_1 + '/allValLoss.txt')
    allTrainLoss_tmp_1 = np.loadtxt(out_path_1 + '/allTrainLoss.txt')

    # populate new array to preserve the epoch number (we might re-run with a higher epoch number to continue training)
    allTrainLoss_1[0:allTrainLoss_tmp_1.size] = allTrainLoss_tmp_1
    allValLoss_1[0:allValLoss_tmp_1.size] = allValLoss_tmp_1

    print('Loaded previous loss history for archi 1')
Beispiel #20
0
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint and return the iteration to resume from.

    Arguments:
        model: the model to restore.  With ``args.deepspeed`` it must expose
            ``load_checkpoint``; otherwise it is a module (a ``torchDDP``
            wrapper is unwrapped first) restored via ``load_state_dict``.
        optimizer: optimizer to restore, or None to skip it.
        lr_scheduler: learning-rate scheduler to restore, or None to skip it.
        args: run configuration; reads ``deepspeed``, ``load``, ``finetune``,
            ``no_load_optim`` and ``no_load_rng``.

    Returns:
        0 when no checkpoint could be located, or when fine-tuning / loading
        a release checkpoint; otherwise the iteration stored in the
        checkpoint.  If the deepspeed engine fails to load, the iteration
        reported by ``get_checkpoint_iteration`` is returned unchanged.

    The process exits when a required entry (model, optimizer, iteration or
    RNG state) is missing from the checkpoint.
    """

    iteration, release, success = get_checkpoint_iteration(args)

    if not success:
        return 0

    if args.deepspeed:

        # Optimizer and scheduler states are deliberately not loaded here.
        checkpoint_name, sd = model.load_checkpoint(
            args.load,
            iteration,
            load_module_strict=False,
            load_optimizer_states=False,
            load_lr_scheduler_states=False)

        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return iteration

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(args.load, iteration, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:
            model.load_state_dict(sd['model'])
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but Unable to load iteration '
                    ' from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            # This branch is about the RNG states, so name them (and the flag
            # that actually disables this step) instead of the optimizer.
            print_rank_0(
                'Unable to load rng state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the rng '
                'state.'.format(checkpoint_name))
            exit()

    # Wait for every rank before reporting success.
    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
Beispiel #21
0
def train_AE():
    """Train the module-level encoder/decoder pair as a sparse autoencoder.

    The loss is the MSE reconstruction error plus ``lambda_sparsity`` times
    the mean of the encoded features.  When ``resume_epoch`` is positive the
    networks, the optimizer, the torch CPU RNG state and the step counter are
    restored from the checkpoint written for that epoch.  A grid of
    reconstructions is saved every 100 steps, progress is printed every 20
    steps, and a checkpoint is written every ``args.save_ckpt_freq`` epochs.

    Returns:
        The ``(net_encoder, net_decoder)`` pair.
    """

    def _checkpoint_path(epoch_number):
        # File used for the in-training checkpoint taken after `epoch_number` epochs.
        return save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(
            epoch_number, lambda_sparsity)

    # One Adam optimizer over the parameters of both networks.
    ae_parameters = list(net_encoder.parameters()) + list(
        net_decoder.parameters())
    optimizer = torch.optim.Adam(ae_parameters,
                                 lr=base_lr,
                                 betas=(0.5, 0.999),
                                 weight_decay=args.weight_dacay)

    # Reconstruction criterion.
    criterion = nn.MSELoss()

    gen_iterations = 0
    if resume_epoch > 0:
        print("Loading ckpt to resume training AE >>>")
        saved = torch.load(_checkpoint_path(resume_epoch))
        net_encoder.load_state_dict(saved['net_encoder_state_dict'])
        net_decoder.load_state_dict(saved['net_decoder_state_dict'])
        optimizer.load_state_dict(saved['optimizer_state_dict'])
        torch.set_rng_state(saved['rng_state'])
        gen_iterations = saved['gen_iterations']

    start_time = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):

        adjust_learning_rate(epoch, epochs, optimizer, base_lr,
                             lr_decay_epochs, lr_decay_factor)

        train_loss = 0

        for batch_idx, batch_real_images in enumerate(trainloader):

            # The periodic preview below switches to eval mode, so training
            # mode is re-enabled at the start of every step.
            net_encoder.train()
            net_decoder.train()

            batch_size_curr = batch_real_images.shape[0]
            batch_real_images = batch_real_images.type(torch.float).cuda()

            encoded = net_encoder(batch_real_images)
            reconstructed = net_decoder(encoded)

            # Sparsity penalty based on
            # https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
            reconstruction_error = criterion(reconstructed, batch_real_images)
            loss = reconstruction_error + lambda_sparsity * encoded.mean()

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
            gen_iterations += 1

            if gen_iterations % 100 == 0:
                # Save a square grid (at most 10 x 10) of reconstructions.
                n_row = min(10, int(np.sqrt(batch_size_curr)))
                net_encoder.eval()
                net_decoder.eval()
                with torch.no_grad():
                    preview = net_decoder(
                        net_encoder(batch_real_images[0:n_row**2]))
                    preview = preview.detach().cpu()
                save_image(preview.data,
                           save_AE_images_in_train_folder +
                           '/{}.png'.format(gen_iterations),
                           nrow=n_row,
                           normalize=True)

            if gen_iterations % 20 == 0:
                running_mean_loss = train_loss / (batch_idx + 1)
                elapsed = timeit.default_timer() - start_time
                print(
                    "AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]"
                    .format(lambda_sparsity, gen_iterations, epoch + 1, epochs,
                            running_mean_loss, elapsed))
        # end of the loop over batches

        if (epoch + 1) % args.save_ckpt_freq == 0:
            save_file = _checkpoint_path(epoch + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            snapshot = {
                'gen_iterations': gen_iterations,
                'net_encoder_state_dict': net_encoder.state_dict(),
                'net_decoder_state_dict': net_decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': torch.get_rng_state()
            }
            torch.save(snapshot, save_file)
    # end of the loop over epochs

    return net_encoder, net_decoder
Beispiel #22
0
    net = torch.nn.DataParallel(net)
    print('Using', torch.cuda.device_count(), 'GPUs.')
    cudnn.benchmark = True
    print('Using CUDA..')

# Model
# Optionally restore the network, the best accuracy, the epoch to continue
# from and the torch CPU RNG state from the checkpoint file given by
# args.resume.
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # NOTE(review): this checks for a local 'checkpoint' directory although
    # the file that is actually loaded is args.resume -- confirm the intent.
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # checkpoint_file = './checkpoint/ckpt.t7.' + args.sess + '_' + str(args.seed)
    checkpoint = torch.load(args.resume)
    # checkpoint['net'] is an object exposing state_dict(); only its weights
    # are copied into the already constructed `net`.
    net.load_state_dict(checkpoint['net'].state_dict())
    best_acc = checkpoint['acc']
    # Continue with the epoch after the one stored in the checkpoint.
    start_epoch = checkpoint['epoch'] + 1
    torch.set_rng_state(checkpoint['rng_state'])

# Make sure the output directory for results exists.
result_folder = './results/'
if not os.path.exists(result_folder):
    os.makedirs(result_folder)

if 'nobn' in args.arch or 'fixup' in args.arch or args.no_bn and 'resnet' in args.arch:
    parameters_bias = [p[1] for p in net.named_parameters() if 'bias' in p[0]]
    parameters_scale = [p[1] for p in net.named_parameters() if 'scale' in p[0]]
    parameters_others = [p[1] for p in net.named_parameters() if not ('bias' in p[0] or 'scale' in p[0] or 'autoinit' in p[0])]
    optimizer = optim.SGD(
            [{'params': parameters_bias, 'lr': args.base_lr/10.},
            {'params': parameters_scale, 'lr': args.base_lr/10.},
            {'params': parameters_others}],
            lr=base_learning_rate,
            momentum=0.9,
Beispiel #23
0
def fork_rng(devices=None,
             enabled=True,
             _caller="fork_rng",
             _devices_kw="devices"):
    """
    Snapshot the RNG state, yield once, and put the snapshot back afterwards,
    so random draws made while the generator is suspended do not affect the
    random stream seen after it finishes.

    Arguments:
        devices (iterable of CUDA IDs): CUDA devices whose RNG state is
            saved and restored.  The CPU RNG state is always handled.  When
            omitted, every device reported by ``torch.cuda.device_count()``
            is covered, and a one-time warning is emitted if there is more
            than one such device.  Passing devices explicitly suppresses
            that warning.
        enabled (bool): if False, nothing is saved or restored; the function
            just yields.  Handy for switching the behaviour off without
            re-indenting the calling code.

    Internal arguments (only used to word the warning):
        _caller: name of the user-facing function that called fork_rng.
        _devices_kw: name of the devices keyword of ``_caller``.
    """

    import torch.cuda
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is not None:
        # Materialise the iterable: it is walked twice (save and restore) and
        # a generator would be exhausted after the first pass.
        devices = list(devices)
    else:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            message = (
                "CUDA reports that you have {num_devices} available devices, and you "
                "have used {caller} without explicitly specifying which devices are being used. "
                "For safety, we initialize *every* CUDA device by default, which "
                "can be quite slow if you have a lot of GPUs.  If you know that you are only "
                "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                "you are actually using.  For example, if you are using CPU only, "
                "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0].  To initialize "
                "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                "to `range(torch.cuda.device_count())`.").format(
                    num_devices=num_devices,
                    caller=_caller,
                    devices_kw=_devices_kw)
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))

    # Take the snapshots: CPU first, then each requested CUDA device.
    saved_cpu_state = torch.get_rng_state()
    saved_gpu_states = []
    for device_id in devices:
        with torch.cuda.device(device_id):
            saved_gpu_states.append((device_id, torch.cuda.get_rng_state()))

    try:
        yield
    finally:
        # Runs even when the suspended caller raised.
        torch.set_rng_state(saved_cpu_state)
        for device_id, gpu_state in saved_gpu_states:
            with torch.cuda.device(device_id):
                torch.cuda.set_rng_state(gpu_state)
Beispiel #24
0
 def __exit__(self, *args):
     """Restore the NumPy and torch RNG states stored on this object.

     The exception details passed by the ``with`` statement are ignored, and
     nothing is returned, so an exception raised inside the block propagates.
     """
     # self.numpy_state / self.torch_state are presumably captured in
     # __enter__ (not visible here) -- TODO confirm.
     numpy.random.set_state(self.numpy_state)
     torch.set_rng_state(self.torch_state)
Beispiel #25
0



# Model Configuration
# Either restore a previous run from its checkpoint or build a fresh model.
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # The checkpoint file name is derived from the run name and the seed.
    checkpoint = torch.load('./checkpoint/ckpt.t7' + args.name + '_'
                            + str(args.seed))
    # The checkpoint stores the network object itself, not just a state dict.
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    # Continue with the epoch after the one stored in the checkpoint.
    start_epoch = checkpoint['epoch'] + 1
    # Put the torch CPU RNG back into the state saved with the checkpoint.
    rng_state = checkpoint['rng_state']
    torch.set_rng_state(rng_state)
else:
    print('==> Building model..')
    # Look the model constructor up by name and call it without arguments.
    net = models.__dict__[args.model]()







# check_location is defined elsewhere; presumably it ensures the directory
# exists -- TODO confirm.
check_location('./results')
# CSV log path built from the model class name, the run name and the seed.
logname = ('./results/log_' + net.__class__.__name__ + '_' + args.name + '_'
           + str(args.seed) + '.csv')
if not os.path.exists(logname):
    with open(logname, 'w') as logfile:
Beispiel #26
0
def load_checkpoint(neox_args,
                    model,
                    optimizer,
                    lr_scheduler,
                    inference=False,
                    iteration=None):
    """Load a model checkpoint and return the iteration.

    Arguments:
        neox_args: run configuration; reads ``deepspeed``, ``load``,
            ``no_load_optim``, ``finetune``, ``no_load_rng`` and
            ``checkpoint_validation_with_forward_pass``.
        model: engine exposing ``load_checkpoint(...)`` that returns a
            ``(checkpoint_name, state_dict)`` pair.
        optimizer: not used directly in this function; optimizer and
            scheduler states are requested through ``model.load_checkpoint``.
        lr_scheduler: not used directly in this function (see ``optimizer``).
        inference (bool): forwarded to the forward-pass validation.
        iteration (int, optional): load the checkpoint tagged
            ``global_step{iteration}`` instead of the default one.

    Returns:
        0 when no checkpoint was loaded (and no specific iteration was
        requested) or when fine-tuning; otherwise the iteration stored in the
        checkpoint.

    Raises:
        ValueError: if deepspeed is disabled, if a specifically requested
            iteration cannot be loaded, or if the checkpoint carries neither
            an ``iteration`` nor a ``total_iters`` entry.

    The process exits if the RNG states are requested but missing.
    """
    if neox_args.deepspeed:
        load_optim_and_scheduler = (
            not neox_args.no_load_optim
        )  # TODO: These should be configured by separate args
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = f"global_step{iteration}"
        else:
            tag = None
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            tag=tag,
        )

        if checkpoint_name is None:
            # if an iteration is specified, we want to raise an error here rather than
            # continuing silently, since we are trying to load a specific checkpoint
            if iteration is not None:
                available_checkpoints = sorted([
                    int(i.name.replace("global_step", ""))
                    for i in Path(neox_args.load).glob("global_step*")
                ])
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
                )
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0  # iteration 0, if not checkpoint loaded
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # Set iteration.
    if neox_args.finetune:
        iteration = 0
    else:
        # Test for None explicitly: a stored iteration of 0 is a valid value
        # and must not be mistaken for a missing entry.
        iteration = state_dict.get("iteration")
        if iteration is None:
            # total_iters backward compatible with older checkpoints
            iteration = state_dict.get("total_iters")
        if iteration is None:
            raise ValueError(
                f"Unable to load iteration from checkpoint {checkpoint_name} with keys {state_dict.keys()}, exiting"
            )

    # Check arguments.
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args,
                              checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(
            " > could not find arguments in the checkpoint for validation...")

    # Check loaded checkpoint with forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(
                " > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}"
                    .format(checkpoint_name))

    # rng states.
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict["rng_tracker_states"])
        except KeyError:
            # This branch restores RNG states, so the message names them
            # rather than the optimizer.
            print_rank_0("Unable to load rng state from checkpoint {}. "
                         "Specify --no-load-rng or --finetune to prevent "
                         "attempting to load the rng state, "
                         "exiting ...".format(checkpoint_name))
            sys.exit()

    # Wait for every rank before reporting success.
    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print("  successfully loaded {}".format(checkpoint_name))

    return iteration
Beispiel #27
0
def manual_seed(seed):
    """Temporarily seed the torch RNG with ``seed``, then restore its state.

    Generator with a single ``yield``: the current torch CPU RNG state is
    saved, ``torch.manual_seed(seed)`` is applied, control is handed to the
    caller, and the saved state is put back afterwards.

    Arguments:
        seed (int): seed passed to ``torch.manual_seed``.
    """
    torch_state = torch.get_rng_state()
    torch.manual_seed(seed)
    try:
        yield
    finally:
        # Restore even if the caller's code raised while suspended here;
        # otherwise the temporary seed would leak into later random draws.
        torch.set_rng_state(torch_state)
Beispiel #28
0
def main():
    """Train (unless ``--eval``) and evaluate the coreference model.

    Steps: parse hyper-parameters, create the model directories, seed torch,
    build the model and data iterators, create the optimizers and scheduler,
    resume from ``model.pt`` when it exists, train, run the final evaluation
    and write validation/test F1 to a performance file.
    """
    hp = parse_args()

    # Directory layout: <model_dir>/<model_name>/best_models
    model_path = path.join(hp.model_dir, get_model_name(hp))
    best_model_path = path.join(model_path, 'best_models')
    for directory in (model_path, best_model_path):
        if not path.exists(directory):
            os.makedirs(directory)

    # Seed before the model is constructed.
    torch.manual_seed(hp.seed)

    model = CorefModel(**vars(hp)).cuda()
    sys.stdout.flush()

    logging.info("Loading data")
    train_iter, val_iter, test_iter = CorefDataset.iters(
        hp.data_dir,
        model.encoder,
        batch_size=hp.batch_size,
        eval_batch_size=hp.eval_batch_size,
        train_frac=hp.train_frac)
    logging.info("Data loaded")

    # A separate optimizer for the core parameters exists only when fine-tuning.
    # TODO(shtoshni): Fix the parameters stuff
    optimizer_tune = (torch.optim.Adam(model.get_core_params(),
                                       lr=hp.lr_tune)
                      if hp.fine_tune else None)
    optimizer = torch.optim.Adam(model.get_other_params(), lr=hp.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           patience=5,
                                                           factor=0.5,
                                                           verbose=True)

    # Values used when there is no checkpoint to resume from.
    steps_done = 0
    max_f1 = 0
    init_num_stuck_evals = 0

    # Total step budget, rounded up to a whole number of evaluation intervals.
    raw_steps = (hp.n_epochs * len(train_iter.data())) // hp.batch_size
    num_steps = int(math.ceil(raw_steps / hp.eval_steps)) * hp.eval_steps
    logging.info("Total training steps: %d" % num_steps)

    location = path.join(model_path, "model.pt")
    if path.exists(location):
        logging.info("Loading previous checkpoint")
        saved = torch.load(location)
        model.encoder.weighing_params = saved['weighing_params']
        model.span_net.load_state_dict(saved['span_net'])
        model.label_net.load_state_dict(saved['label_net'])
        # The encoder weights are restored only when it is being fine-tuned.
        if hp.fine_tune:
            model.encoder.load_state_dict(saved['encoder'])
        optimizer.load_state_dict(saved['optimizer_state_dict'])
        scheduler.load_state_dict(saved['scheduler_state_dict'])
        steps_done = saved['steps_done']
        init_num_stuck_evals = saved['num_stuck_evals']
        max_f1 = saved['max_f1']
        torch.set_rng_state(saved['rng_state'])
        logging.info("Steps done: %d, Max F1: %.3f" % (steps_done, max_f1))

    if not hp.eval:
        train(model,
              train_iter,
              val_iter,
              optimizer,
              optimizer_tune,
              scheduler,
              model_path,
              best_model_path,
              init_steps=steps_done,
              max_f1=max_f1,
              eval_steps=hp.eval_steps,
              num_steps=num_steps,
              init_num_stuck_evals=init_num_stuck_evals)

    val_f1, test_f1 = final_eval(hp, best_model_path, val_iter, test_iter)

    perf_dir = path.join(hp.model_dir, "perf")
    if not path.exists(perf_dir):
        os.makedirs(perf_dir)
    # With a SLURM id the report goes to the shared perf directory,
    # otherwise next to the model.
    perf_file = (path.join(perf_dir, hp.slurm_id + ".txt")
                 if hp.slurm_id else path.join(model_path, "perf.txt"))
    with open(perf_file, "w") as f:
        f.write("%s\n" % (model_path))
        for split_name, score in (("Valid", val_f1), ("Test", test_f1)):
            f.write("%s\t%.4f\n" % (split_name, score))
Beispiel #29
0
def main():
    """Train and/or evaluate an image classifier with optional learned augmentation.

    Everything is driven by the parsed command-line arguments:

    * the model is built (optionally pre-trained) and moved to ``args.gpu``;
    * with ``args.resume`` the model, the best accuracy, the start epoch and
      (when training) the optimizer state are restored from a checkpoint;
    * with ``args.train`` the training data, the optional ``AugModel`` and
      the optional hyper-optimisation (``args.ho``) train/validation split
      are set up and the training loop runs, saving a checkpoint every
      epoch;
    * with ``args.evaluate`` set to ``'clean'`` or ``'corrupted'`` a final
      evaluation is run; the corrupted results are pickled to
      ``robustness_epoch_<e>.pkl``.
    """
    best_acc1 = 0
    args = parser.parse_args()
    assert args.batch_size % args.effective_bs == 0, "Effective batch size must be a divisor of batch_size"

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    if args.train:
        writer = SummaryWriter()
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                #best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            if args.train:
                # Best effort: keep going with a fresh optimizer if the saved
                # state is missing or unusable.  Only ordinary errors are
                # caught so that Ctrl-C / SystemExit still abort the run.
                try:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    print("=> loaded optimizer state from checkpoint")
                except Exception:
                    print("=> optimizer state not found in checkpoint")
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    norm_params = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
    normalize = transforms.Normalize(mean=norm_params['mean'],
                                     std=norm_params['std'])
    test_loader = get_test_loader(args, normalize)
    if args.evaluate == 'corrupted':
        corrupted_test_loader = get_test_loader(
            args, normalize,
            lambda img: apply_random_corruption(img, test=True))

    if args.train:
        if args.augment_train_data:
            # as augmodel will be applied before normalization,
            # the dataset transform stops at ToTensor (no normalize here)
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.ToTensor(),
                ]))
            if args.augmentations:
                ops = []
                for aug in args.augmentations:
                    ops.append(augmentations.__dict__[aug])
            else:
                ops = augmentations.standard_augmentations
            # Names requested on the command line; empty when none were given
            # so the membership tests below cannot fail on None.
            requested_augs = args.augmentations or ()
            # initialize augmodel
            print('Using augmentations ' + str(ops))
            augmodel = AugModel(norm_params=norm_params,
                                augmentations=ops,
                                augmentation_mean=args.augmentation_mean,
                                augmentation_std=args.augmentation_std,
                                min_magnitude=args.min_magnitude,
                                max_magnitude=args.max_magnitude)
            if args.resume and 'augmodel_state_dict' in checkpoint.keys():
                augmodel.load_state_dict(checkpoint['augmodel_state_dict'])
            if args.gpu is not None:
                augmodel = augmodel.cuda(args.gpu)
                augmodel.augmentations[1].enc_to()
                augmodel.augmentations[1].dec_to()
            if 'AdaptiveStyleTransfer' in requested_augs:
                ast = augmodel.augmentations[1]
                ast.initStyles(args.style_subset, seed=args.seed)
                if args.style_indices:
                    assert args.effective_bs % len(
                        args.style_indices
                    ) == 0, "Number of style indices must be a divisor of effective bs!"
                    ast._PainterByNumbers = torch.utils.data.dataset.Subset(
                        ast._PainterByNumbers, args.style_indices)
                    ast.initStyles(len(args.style_indices), seed=args.seed)
                    # Tile the style features so that there are
                    # effective_bs of them.
                    factor = args.effective_bs // len(args.style_indices)
                    ast.style_features = ast.style_features.repeat(
                        factor, 1, 1, 1)
            if 'StyleTransfer' in requested_augs and args.style_subset is not None:
                op = augmodel.augmentations[1]
                assert str(op) == 'StyleTransfer'
                pbn = op._PainterByNumbers
                assert 0 < args.style_subset < len(pbn)
                if args.seed:
                    rng_state = torch.get_rng_state(
                    )  # save the pseudo-random state
                    torch.manual_seed(
                        args.seed
                    )  # set the seed for deterministic dataset splits
                pbn_split, _ = torch.utils.data.dataset.random_split(
                    pbn, [args.style_subset,
                          len(pbn) - args.style_subset])
                if args.seed:
                    torch.set_rng_state(
                        rng_state
                    )  # reset the state for non-deterministic behaviour below
                op._PainterByNumbers = pbn_split
                op.resetStyleLoader(args.effective_bs)
        else:
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.ToTensor(), normalize
                ]))
            augmodel = None

        if args.ho:
            ho_criterion = nn.CrossEntropyLoss().cuda(args.gpu)
            ho_optimizer = torch.optim.Adam(
                [p for p in augmodel.parameters() if p.requires_grad],
                args.ho_lr)
            if args.resume and 'ho_optimizer' in checkpoint.keys():
                # Best effort, as for the main optimizer above.
                try:
                    ho_optimizer.load_state_dict(checkpoint['ho_optimizer'])
                    print("=> loaded optimizer state from checkpoint")
                except Exception:
                    print("=> optimizer state not found in checkpoint")

            # train/val split
            train_size = int(len(train_dataset) * args.train_size)
            if args.seed:
                rng_state = torch.get_rng_state(
                )  # save the pseudo-random state
                torch.manual_seed(
                    args.seed)  # set the seed for deterministic dataset splits
            train_split, val_split = torch.utils.data.dataset.random_split(
                train_dataset,
                [train_size, len(train_dataset) - train_size])
            if args.seed:
                torch.set_rng_state(
                    rng_state
                )  # reset the state for non-deterministic behaviour below
            if args.validation_objective == 'clean':
                val_transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])
            elif args.validation_objective == 'corrupted':
                val_transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.Lambda(apply_random_corruption),
                    transforms.ToTensor(),
                    normalize,
                ])
            # as the underlying dataset of both splits is the same, this is the only way of having separate transforms for train and val split
            val_dataset = datasets.ImageFolder(traindir,
                                               transform=val_transform)
            val_split.dataset = val_dataset

            train_loader = torch.utils.data.DataLoader(
                train_split,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)

            val_loader = InfiniteDataLoader(val_split,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=args.workers,
                                            pin_memory=True,
                                            drop_last=True)
        else:
            if args.path_to_stylized and not args.augment_train_data:
                stylized_imagenet = datasets.ImageFolder(
                    root=traindir,
                    loader=stylized_loader,
                    transform=transforms.Compose(
                        [transforms.ToTensor(), normalize]))
                train_dataset = torch.utils.data.ConcatDataset(
                    [train_dataset, stylized_imagenet])

            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
            val_loader = None
            ho_criterion = None
            ho_optimizer = None

        # training
        for epoch in range(args.start_epoch, args.epochs):
            # Halve the temperature every `decrease_temperature` epochs,
            # counted from the start epoch (but not at the start epoch itself).
            if args.decrease_temperature is not None and (
                    epoch - args.start_epoch
            ) % args.decrease_temperature == 0 and not epoch == args.start_epoch:
                augmodel.augmentations[1].temperature /= 2
            # Every `increasing_alpha` epochs: snapshot the model under the
            # current alpha, then raise alpha by 0.1.
            if args.increasing_alpha is not None and (
                    epoch - args.start_epoch) % args.increasing_alpha == 0:
                op = augmodel.augmentations[1]
                assert isinstance(op, StyleTransfer)
                current_alpha = op.mu_mag

                ckpt = {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }
                if args.ho:
                    ckpt['augmodel_state_dict'] = augmodel.state_dict()
                    ckpt['ho_optimizer'] = ho_optimizer.state_dict()
                save_checkpoint(ckpt,
                                is_best=False,
                                filename='checkpoint_alpha_%1.3f.pth.tar' %
                                (current_alpha.item()))

                updated_alpha = current_alpha + 0.1
                op.mu_mag = updated_alpha
                print("=> alpha=%1.2f" % (op.mu_mag.item()))
            train(train_loader, val_loader, model, augmodel, criterion,
                  ho_criterion, optimizer, ho_optimizer, epoch, args, writer)
            is_best = False
            # evaluate on validation set
            if epoch % args.print_freq == 0:
                acc1 = validate(test_loader, model, criterion, args)
                writer.add_scalar('Metrics/test_acc', acc1, epoch)
                if args.evaluate == 'corrupted':
                    mpc = validate(corrupted_test_loader, model, criterion,
                                   args)
                    writer.add_scalar('Metrics/test_mpc', mpc, epoch)

                # remember best acc@1 and save checkpoint
                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)

            ckpt = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }

            if args.ho:
                ckpt['augmodel_state_dict'] = augmodel.state_dict()
                ckpt['ho_optimizer'] = ho_optimizer.state_dict()

            save_checkpoint(ckpt, is_best)

    if args.evaluate == 'clean':
        validate(test_loader, model, criterion, args)
    elif args.evaluate == 'corrupted':
        corruptions = ic.get_corruption_names('all')
        severities = [0, 1, 2, 3, 4, 5]
        accuracies = {}
        for corruption in corruptions:
            accuracies[corruption] = {}
            for severity in severities:
                if severity == 0:
                    # Severity 0 is the uncorrupted test set.
                    print('Testing clean')
                    acc = validate(test_loader, model, criterion, args)
                    accuracies[corruption][severity] = torch.squeeze(
                        acc.cpu()).item()
                else:
                    print('Testing %s:%d' % (corruption, severity))
                    corrupted_loader = get_test_loader(
                        args, normalize, lambda x: Image.fromarray(
                            ic.corrupt(np.array(x, dtype=np.uint8),
                                       corruption_name=corruption,
                                       severity=severity)))
                    acc = validate(corrupted_loader, model, criterion, args)
                    accuracies[corruption][severity] = torch.squeeze(
                        acc.cpu()).item()
        # Epoch used in the result file name.  Without --train the start
        # epoch is used (taken from the checkpoint when resuming), so `e` is
        # always bound, even for a plain evaluation run.
        if args.train:
            e = args.epochs
        else:
            e = args.start_epoch
        # Close the file deterministically instead of leaking the handle.
        with open("robustness_epoch_{}.pkl".format(e), "wb") as results_file:
            pickle.dump(accuracies, results_file)
 def tearDown(self):
     """Put the torch RNG back to ``self.rng_state`` if one was recorded."""
     if not hasattr(self, "rng_state"):
         # Nothing was recorded for this test, so leave the RNG untouched.
         return
     torch.set_rng_state(self.rng_state)
Beispiel #31
0
def _train(data_train,
           Nminibatch,
           order,
           C,
           rng,
           lr_train,
           debug,
           maxiter,
           maxtime,
           init,
           dftol_stop,
           freltol_stop,
           dn_log,
           accum_steps,
           path_save,
           shuffle,
           device=constants.Device.CPU,
           verbose=1,
           prev_checkpoint=None,
           groups=None,
           soft_groups=None):
    """
    Main training loop.

    Builds a ``Solver`` around ``data_train``, optionally restores it from
    ``prev_checkpoint``, runs ``Solver.train`` until the stop function built
    by ``get_optim_f_stop`` fires, and returns the resulting checkpoint.

    Args:
        data_train: training data, forwarded to ``get_init`` and ``Solver``.
        Nminibatch, order, C, accum_steps, shuffle, groups, soft_groups,
            verbose: forwarded unchanged to ``Solver``.
        rng: random generator forwarded to ``get_init``, ``Solver`` and
            ``get_checkpoint``.
        lr_train: learning rate of the Adam optimizer handed to ``Solver``.
        debug: accepted for interface compatibility; no debug callback is
            implemented, so it currently has no effect.
        maxiter, maxtime, dftol_stop, freltol_stop: stopping criteria,
            forwarded to ``get_optim_f_stop``.
        init: initialization spec forwarded to ``get_init``; when it equals
            ``constants.Initialization.ZERO`` the initial feature count is
            recorded as -1.
        dn_log, path_save: not used inside this function.
        device: device the solver is moved to.
        prev_checkpoint: optional mapping holding model state, optimizer
            state and torch RNG state under the ``constants.Checkpoint`` keys.

    Returns:
        A tuple ``(checkpoint, solver)`` where ``checkpoint`` is the value of
        ``get_checkpoint(solver, stop_conds, rng)``.
    """

    t_init = time.time()

    x0 = get_init(data_train, init, rng)
    if isinstance(init, str) and init == constants.Initialization.ZERO:
        ninitfeats = -1
    else:
        # Number of strictly positive entries in the initial point.
        ninitfeats = np.where(x0.detach().numpy() > 0)[0].size

    S = Solver(data_train,
               order,
               Nminibatch=Nminibatch,
               x0=x0,
               C=C,
               ftransform=lambda x: torch.sigmoid(2 * x),
               get_train_opt=lambda p: torch.optim.Adam(p, lr_train),
               rng=rng,
               accum_steps=accum_steps,
               shuffle=shuffle,
               groups=groups,
               soft_groups=soft_groups,
               device=device,
               verbose=verbose)
    S = S.to(device)

    S.ninitfeats = ninitfeats
    S.x0 = x0

    if prev_checkpoint:
        # Resume: restore model, optimizer and the torch CPU RNG stream.
        S.load_state_dict(prev_checkpoint[constants.Checkpoint.MODEL])
        S.opt_train.load_state_dict(prev_checkpoint[constants.Checkpoint.OPT])
        torch.set_rng_state(prev_checkpoint[constants.Checkpoint.RNG])

    minibatch = S.Ntrain != S.Nminibatch

    f_stop, stop_conds = get_optim_f_stop(maxiter,
                                          maxtime,
                                          dftol_stop,
                                          freltol_stop,
                                          minibatch=minibatch)

    # No debug callback exists yet. Bind the name unconditionally so that
    # debug=True no longer fails with an unbound local at S.train below.
    f_callback = None

    # Record the time spent on setup in the latest 't' stop-condition entry.
    stop_conds['t'][-1] = time.time() - t_init

    S.train(f_stop=f_stop, f_callback=f_callback)

    return get_checkpoint(S, stop_conds, rng), S
Beispiel #32
0
    def go_evolve():
        """Run the evolution loop and persist results after every iteration.

        Reads its configuration (``seed``, ``rng``, ``task``, ``out_path``,
        ``iterations`` ...) from the enclosing scope. For each iteration it
        drives the ``task.algo.evolve(habitat)`` protocol stage by stage,
        then saves the task, the survivors and three sample videos.
        """
        # Seed first; an explicit RNG state, when given, overrides the seed.
        torch.manual_seed(seed)
        if not rng is None:
            torch.set_rng_state(rng)
        device = torch.device('cuda') if cuda_ok else torch.device('cpu')

        adam = task.adam
        population = task.population
        adam = (adam.to(device) if not task.adam is None else None)
        # Start from the stored population, or from a single fresh bloodline
        # built from `adam` when no population exists yet.
        survivors = (
            # [ individual .to (device) for individual in population ] if not population is None else
            population
            if not population is None else [bloodline(original(adam))])

        pool_log_file = os.path.join(out_path, 'pool.log')
        sample_log_file = os.path.join(out_path, 'sample.log')
        habitat = parallel_habitat(*[map_name, parallelism],
                                   batch_size=batch_size,
                                   frame_skip=frame_skip,
                                   distortion=distortion,
                                   max_steps=max_steps,
                                   pool_log_file=pool_log_file,
                                   sample_log_file=sample_log_file)

        for iteration in range(iteration_offset, iterations):
            # Keep the task's offset current so a saved task records it.
            task.iteration_offset = iteration

            # TODO: move models to cuda per batch only
            # `evolution` yields (stage, progress) pairs; `send` passes the
            # current generation back to the algorithm after each stage.
            with co_(task.algo.evolve(habitat)) as (evolution, send):
                for stage, progress in evolution:
                    if stage == 'generation':
                        print(
                            '--------------------------------------------------------------------------------'
                        )
                        print('Generation: {}'.format(iteration + 1))
                        print(
                            '--------------------------------------------------------------------------------'
                        )
                        send(generation_from(survivors))
                    elif stage == 'pre-elites':
                        pre_elites_survivors, population_size = progress

                        survivors = []
                        i = 0
                        # TODO: whatif log_interval > batch_size?
                        for batch in chunks(pre_elites_survivors, batch_size):
                            batch = list(batch)
                            for subbatch in chunks(iter(batch), log_interval):
                                subbatch = list(subbatch)
                                # Every element is passed through desiderata,
                                # but only the LAST result is kept.
                                # NOTE(review): presumably each result is the
                                # cumulative survivor list so far - confirm.
                                *_, new_survivors = [
                                    desiderata(character)
                                    for character in subbatch
                                ]
                                i += len(subbatch)
                                survivors = new_survivors
                                # NOTE(review): the Min/Median/Max labels
                                # assume survivors is sorted by descending
                                # fitness - confirm against the producer.
                                print(
                                    'Generation: {} Pre-elites [{}/{} ({:.0f}%)]\tMin: {:.6f}\tMedian: {:.6f}\tMax: {:.6f}\tAverage: {:.6f}'
                                    .format(
                                        iteration + 1, i, population_size,
                                        100. * i / population_size,
                                        survivors[-1].fitness,
                                        survivors[len(survivors) // 2].fitness,
                                        survivors[0].fitness,
                                        sum([
                                            individual.fitness
                                            for individual in survivors
                                        ]) / len(survivors)))
                        send(generation_from(survivors))
                    elif stage == 'elites':
                        # Same processing as the 'pre-elites' stage; only the
                        # progress label differs.
                        elites_survivors, population_size = progress

                        survivors = []
                        i = 0
                        for batch in chunks(elites_survivors, batch_size):
                            batch = list(batch)
                            for subbatch in chunks(iter(batch), log_interval):
                                subbatch = list(subbatch)
                                *_, new_survivors = [
                                    desiderata(character)
                                    for character in subbatch
                                ]
                                i += len(subbatch)
                                survivors = new_survivors
                                print(
                                    'Generation: {} Elites [{}/{} ({:.0f}%)]\tMin: {:.6f}\tMedian: {:.6f}\tMax: {:.6f}\tAverage: {:.6f}'
                                    .format(
                                        iteration + 1, i, population_size,
                                        100. * i / population_size,
                                        survivors[-1].fitness,
                                        survivors[len(survivors) // 2].fitness,
                                        survivors[0].fitness,
                                        sum([
                                            individual.fitness
                                            for individual in survivors
                                        ]) / len(survivors)))
                        send(generation_from(survivors))

            # Persist the task and the survivors of this iteration.
            # NOTE(review): `file` is presumably a helper from the enclosing
            # scope that resolves a name to an output path - confirm.
            torch.save(save_task(task),
                       file('task_' + str(iteration + 1) + '.pt'))
            torch.save(save_population(survivors),
                       file('elites_' + str(iteration + 1) + '.pt'))

            # Render sample videos for the first three survivors. Indexing
            # [1] and [2] requires at least three survivors at this point.
            desiderata(
                sample_visualization(
                    survivors[0], habitat,
                    file('sample_champion_' + str(iteration + 1) + '.mp4')))
            desiderata(
                sample_visualization(
                    survivors[1], habitat,
                    file('sample_first-runner_' + str(iteration + 1) +
                         '.mp4')))
            desiderata(
                sample_visualization(
                    survivors[2], habitat,
                    file('sample_second-runner_' + str(iteration + 1) +
                         '.mp4')))
def set_rng_state(state: Dict[str, Any]) -> None:
    """Restore torch RNG streams from a previously captured state mapping.

    Args:
        state: mapping with the CPU generator state under
            ``"torch_rng_state"`` and, optionally, the CUDA generator state
            under ``"cuda_rng_state"``.

    Raises:
        KeyError: if ``"torch_rng_state"`` is missing.
    """
    torch.set_rng_state(state["torch_rng_state"])
    # A state captured on a CPU-only host has no CUDA entry; restoring it on
    # a CUDA host must not fail, so the CUDA stream is only touched when both
    # the saved state and a device are present.
    cuda_state = state.get("cuda_rng_state")
    if cuda_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_state)
Beispiel #34
0
def validate():
    """Evaluate netG/netD on the validation iterator with a fixed RNG seed.

    The torch CPU and CUDA RNG states are captured on entry and restored on
    exit (also when an exception is raised), so that validation does not
    perturb the random stream of the surrounding training loop. Both
    networks are switched to eval mode and are left in eval mode.

    Returns:
        OrderedDict mapping each loss name in ``loss_funcs`` to its
        sample-weighted mean over the validation set.

    Raises:
        ZeroDivisionError: if ``iterator_va`` yields no batches.
    """
    # Store random state
    cpu_rng_state_tr = torch.get_rng_state()
    gpu_rng_state_tr = torch.cuda.get_rng_state()

    try:
        # Set random state so every validation pass draws the same noise
        torch.manual_seed(123)

        sum_losses_va = OrderedDict(
            [(loss_name, 0) for loss_name in loss_funcs])

        count_all_va = 0

        # In validation, put both networks in eval mode
        netG.eval()
        netD.eval()
        num_batches_va = len(iterator_va)
        with torch.set_grad_enabled(False):
            for i_batch, batch in enumerate(iterator_va):

                # voc.shape=(bs, feat_dim, num_frames)
                voc = batch
                voc = voc.cuda()
                voc = (voc - mean) / std

                bs, _, nf = voc.size()

                # ### Generator loss ###
                z = torch.zeros((bs, z_dim, nf)).normal_(0, 1).float().cuda()

                fake_voc = netG(z)

                gloss = torch.mean(torch.abs(netD(fake_voc) - fake_voc))

                # ### Discriminator losses ###
                real_dloss = torch.mean(torch.abs(netD(voc) - voc))
                fake_dloss = torch.mean(
                    torch.abs(netD(fake_voc.detach()) - fake_voc.detach()))

                dloss = real_dloss - k * fake_dloss

                # ### Convergence ###
                # update_k=False: k is not updated during validation.
                _, convergence = recorder(real_dloss,
                                          fake_dloss,
                                          update_k=False)

                # ### Losses ###
                losses = OrderedDict([
                    ('G', gloss),
                    ('D', dloss),
                    ('RealD', real_dloss),
                    ('FakeD', fake_dloss),
                    ('Convergence', convergence),
                ])

                # ### Misc ###
                count_all_va += bs

                # Accumulate losses, weighted by batch size
                losses_va = OrderedDict([(loss_name, lo.item())
                                         for loss_name, lo in losses.items()])

                for loss_name, lo in losses_va.items():
                    sum_losses_va[loss_name] += lo * bs

                if i_batch % 10 == 0:
                    print('{}/{}'.format(i_batch, num_batches_va))
        mean_losses_va = OrderedDict([
            (loss_name, slo / count_all_va)
            for loss_name, slo in sum_losses_va.items()
        ])
    finally:
        # Restore rng state even if validation failed part-way through
        torch.set_rng_state(cpu_rng_state_tr)
        torch.cuda.set_rng_state(gpu_rng_state_tr)

    return mean_losses_va
Beispiel #35
0
def train_CcGAN(kernel_sigma,
                kappa,
                train_images,
                train_labels,
                netG,
                netD,
                net_y2h,
                save_images_folder,
                save_models_folder=None,
                clip_label=False):
    '''
    Train a CcGAN generator/discriminator pair with vicinal label sampling.

    Note that train_images are not normalized to [-1,1]

    Hyper-parameters (device, lr_g, lr_d, niters, resume_niters, batch sizes,
    threshold_type, loss_type, dim_gan, ...) are read from module globals.

    Args:
        kernel_sigma: std of the Gaussian noise added to sampled labels.
        kappa: vicinity parameter; half-width of the label window for the
            "hard" threshold, weight decay rate otherwise.
        train_images: array of training images, indexed by an integer array.
        train_labels: numpy array of labels aligned with train_images.
        netG, netD: generator and discriminator; moved to the global device.
        net_y2h: label embedding network; used in eval mode only.
        save_images_folder: folder receiving a sample grid every 100 steps.
        save_models_folder: if given, checkpoints are written below it and
            training resumes from it when resume_niters > 0.
        clip_label: clip noisy target labels to [0, 1].

    Returns:
        Tuple (netG, netD) after training.

    Raises:
        ValueError: if the global loss_type is neither "vanilla" nor "hinge".
    '''

    netG = netG.to(device)
    netD = netD.to(device)
    net_y2h = net_y2h.to(device)
    net_y2h.eval()

    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=lr_g,
                                  betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=lr_d,
                                  betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/CcGAN_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(
            threshold_type, resume_niters)
        # NOTE: torch.load unpickles data; only load trusted checkpoints.
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        # Only the torch CPU RNG stream is checkpointed; numpy's is not.
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    #################
    unique_train_labels = np.sort(np.array(list(set(train_labels))))

    # printed images with labels between the 5-th quantile and 95-th quantile of training labels
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).to(device)
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,
                                                               1).to(device)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):
        # ----- Train Discriminator -----
        ## randomly draw batch_size_max y's from unique_train_labels
        batch_target_labels_in_dataset = np.random.choice(unique_train_labels,
                                                          size=batch_size_max,
                                                          replace=True)
        ## add Gaussian noise; we estimate image distribution conditional on these labels
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_max)
        batch_target_labels_with_epsilon = batch_target_labels_in_dataset + batch_epsilons
        if clip_label:
            batch_target_labels_with_epsilon = np.clip(
                batch_target_labels_with_epsilon, 0.0, 1.0)

        batch_target_labels = batch_target_labels_with_epsilon[
            0:batch_size_disc]

        ## find index of real images with labels in the vicinity of batch_target_labels
        ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
        batch_real_indx = np.zeros(
            batch_size_disc, dtype=int
        )  #index of images in the data; the labels of these images are in the vicinity
        batch_fake_labels = np.zeros(batch_size_disc)

        for j in range(batch_size_disc):
            ## index for real images
            if threshold_type == "hard":
                indx_real_in_vicinity = np.where(
                    np.abs(train_labels - batch_target_labels[j]) <= kappa)[0]
            else:
                # reverse the weight function for SVDL
                indx_real_in_vicinity = np.where(
                    (train_labels - batch_target_labels[j]
                     )**2 <= -np.log(nonzero_soft_weight_threshold) / kappa)[0]

            ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
            while len(indx_real_in_vicinity) < 1:
                # Redraw the noise for this target until its vicinity
                # contains at least one real sample.
                batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
                batch_target_labels[
                    j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
                if clip_label:
                    batch_target_labels = np.clip(batch_target_labels, 0.0,
                                                  1.0)
                ## index for real images
                if threshold_type == "hard":
                    indx_real_in_vicinity = np.where(
                        np.abs(train_labels -
                               batch_target_labels[j]) <= kappa)[0]
                else:
                    # reverse the weight function for SVDL
                    indx_real_in_vicinity = np.where(
                        (train_labels - batch_target_labels[j])**2 <=
                        -np.log(nonzero_soft_weight_threshold) / kappa)[0]
            #end while len(indx_real_in_vicinity)<1

            assert len(indx_real_in_vicinity) >= 1

            batch_real_indx[j] = np.random.choice(indx_real_in_vicinity,
                                                  size=1)[0]

            ## labels for fake images generation
            if threshold_type == "hard":
                lb = batch_target_labels[j] - kappa
                ub = batch_target_labels[j] + kappa
            else:
                lb = batch_target_labels[j] - np.sqrt(
                    -np.log(nonzero_soft_weight_threshold) / kappa)
                ub = batch_target_labels[j] + np.sqrt(
                    -np.log(nonzero_soft_weight_threshold) / kappa)
            lb = max(0.0, lb)
            ub = min(ub, 1.0)
            assert lb <= ub
            assert lb >= 0 and ub >= 0
            assert lb <= 1 and ub <= 1
            batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
        #end for j

        ## draw the real image batch from the training set
        batch_real_images = train_images[batch_real_indx]
        assert batch_real_images.max() > 1
        batch_real_labels = train_labels[batch_real_indx]
        batch_real_labels = torch.from_numpy(batch_real_labels).type(
            torch.float).to(device)

        ## normalize real images
        trainset = IMGs_dataset(batch_real_images, labels=None, normalize=True)
        train_dataloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size_disc, shuffle=False, num_workers=8)
        train_dataloader = iter(train_dataloader)
        # Use the builtin next(): DataLoader iterators in current PyTorch
        # releases no longer provide a .next() method.
        batch_real_images = next(train_dataloader)
        assert len(batch_real_images) == batch_size_disc
        batch_real_images = batch_real_images.type(torch.float).to(device)
        assert batch_real_images.max().item() <= 1

        ## generate the fake image batch
        batch_fake_labels = torch.from_numpy(batch_fake_labels).type(
            torch.float).to(device)
        z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).to(device)
        batch_fake_images = netG(z, net_y2h(batch_fake_labels))

        ## target labels on gpu
        batch_target_labels = torch.from_numpy(batch_target_labels).type(
            torch.float).to(device)

        ## weight vector
        if threshold_type == "soft":
            real_weights = torch.exp(
                -kappa *
                (batch_real_labels - batch_target_labels)**2).to(device)
            fake_weights = torch.exp(
                -kappa *
                (batch_fake_labels - batch_target_labels)**2).to(device)
        else:
            real_weights = torch.ones(batch_size_disc,
                                      dtype=torch.float).to(device)
            fake_weights = torch.ones(batch_size_disc,
                                      dtype=torch.float).to(device)
        #end if threshold type

        # forward pass
        real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
        fake_dis_out = netD(batch_fake_images.detach(),
                            net_y2h(batch_target_labels))

        if loss_type == "vanilla":
            real_dis_out = torch.nn.Sigmoid()(real_dis_out)
            fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
            d_loss_real = -torch.log(real_dis_out + 1e-20)
            d_loss_fake = -torch.log(1 - fake_dis_out + 1e-20)
        elif loss_type == "hinge":
            d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
            d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                "Unsupported loss_type: {!r}".format(loss_type))

        d_loss = torch.mean(
            real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(
                fake_weights.view(-1) * d_loss_fake.view(-1))

        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        # ----- Train Generator -----
        netG.train()

        # generate fake images
        batch_target_labels = batch_target_labels_with_epsilon[
            0:batch_size_gene]
        batch_target_labels = torch.from_numpy(batch_target_labels).type(
            torch.float).to(device)

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).to(device)
        batch_fake_images = netG(z, net_y2h(batch_target_labels))

        # loss
        dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
        if loss_type == "vanilla":
            dis_out = torch.nn.Sigmoid()(dis_out)
            g_loss = -torch.mean(torch.log(dis_out + 1e-20))
        elif loss_type == "hinge":
            g_loss = -dis_out.mean()
        else:
            raise ValueError(
                "Unsupported loss_type: {!r}".format(loss_type))

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter + 1) % 20 == 0:
            print(
                "CcGAN: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]"
                % (niter + 1, niters, d_loss.item(), g_loss.item(),
                   real_dis_out.mean().item(), fake_dis_out.mean().item(),
                   timeit.default_timer() - start_time))

        # Periodically render the fixed noise/label grid for visual inspection.
        if (niter + 1) % 100 == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, net_y2h(y_fixed))
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data,
                           save_images_folder + '/{}.png'.format(niter + 1),
                           nrow=n_row,
                           normalize=True)

        # Checkpoint every save_niters_freq steps and at the final step.
        if save_models_folder is not None and (
            (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/CcGAN_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(
                threshold_type, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter
    return netG, netD