Code example #1
 def __call__(self, module, module_in):
     print('Module inputs before')
     print(module_in)
     for p_name, a in zip(self.param_names, self.amount):
         prune.random_unstructured(module, p_name, amount=a)
     print('Module buffers')
     print(list(module.named_buffers()))
     print('Module params')
     print(list(module.named_parameters()))
     print('Module inputs')
     print(module_in)
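
The __call__ signature above matches PyTorch's forward pre-hook interface (hook(module, input)), so a callable like this can be attached with register_forward_pre_hook. A minimal sketch of such a callable and its effect on a layer; the class name, layer, and amounts are illustrative assumptions, not taken from the original project:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class PruningPreHook:  # hypothetical name mirroring the __call__ above
    def __init__(self, param_names, amounts):
        self.param_names = param_names
        self.amount = amounts

    def __call__(self, module, module_in):
        # prune the listed parameters just before the module runs forward
        for p_name, a in zip(self.param_names, self.amount):
            prune.random_unstructured(module, p_name, amount=a)

layer = nn.Linear(8, 4)
hook = PruningPreHook(["weight", "bias"], [0.3, 0.5])
hook(layer, (torch.randn(2, 8),))             # invoke once to see the effect
print([n for n, _ in layer.named_buffers()])  # ['weight_mask', 'bias_mask']
# layer.register_forward_pre_hook(hook) would make this run before every forward pass
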
Code example #2
    def load(self, filename):

        """
        Loads a trained model from files.
        Parameters
        ----------
        filename : str
            Path to the files. '_settings.json' and '_state_dict.pt' will be added.
        Returns
        -------
            None
        """

        logger.info("Loading model from %s", filename)

        # Load settings and create model
        logger.debug("Loading settings from %s_settings.json", filename)
        with open(filename + "_settings.json", "r") as f:
            settings = json.load(f)
        self._unwrap_settings(settings)
        self._create_model()

        # Load scaling
        try:
            self.x_scaling_means = np.load(filename + "_x_means.npy")
            self.x_scaling_stds = np.load(filename + "_x_stds.npy")
            logger.debug(
                "  Found input scaling information: means %s, stds %s", self.x_scaling_means, self.x_scaling_stds
            )
        except FileNotFoundError:
            logger.warning("Scaling information not found in %s", filename)
            self.x_scaling_means = None
            self.x_scaling_stds = None

        module = self.model.ll1
        print("before pruning")
        print(list(module.named_parameters()))
        print(list(module.named_buffers()))
 
        prune.random_unstructured(module, name="weight", amount=0.3)
        print("before pruning")
        print(list(module.named_parameters()))
        print(list(module.named_buffers()))
        print(self.model.state_dict().keys())
 
        # Load state dict
        logger.debug("Loading state dictionary from %s_state_dict.pt", filename)
        print(self.model.state_dict().keys())       
        self.model.load_state_dict(torch.load(filename + "_state_dict.pt", map_location="cpu"))
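
A side note on why load() prunes self.model.ll1 before calling load_state_dict: applying a pruning method renames the parameter to weight_orig and registers a weight_mask buffer, so a checkpoint saved from a pruned model only matches a module that has already been pruned. A minimal sketch (the Linear layer is illustrative):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
print(sorted(layer.state_dict().keys()))  # ['bias', 'weight']
prune.random_unstructured(layer, name="weight", amount=0.3)
print(sorted(layer.state_dict().keys()))  # ['bias', 'weight_mask', 'weight_orig']
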
Code example #3
def random_prune_model(model, ratio):
    module = model.fc[0]
    # random_unstructured prunes in place and returns the same module object,
    # so the reassignment below is redundant but harmless.
    pruned_module = prune.random_unstructured(
        module, name='weight', amount=ratio
    )  # alternative: prune.custom_from_mask(module, name='weight', mask=mag_w_mask)
    model.fc[0] = pruned_module
    return model
Code example #4
def pruner(model, amount, random=False):
    """
    (amount) total amount of desired sparsity
    """
    for name, module in model.named_modules():
        # prune declared amount of connections in all 2D-conv & Linear layers
        if isinstance(module, torch.nn.Conv2d) or isinstance(
                module, torch.nn.Linear):
            if random:
                prune.random_unstructured(module, name='weight', amount=amount)

            else:
                prune.l1_unstructured(module, name='weight', amount=amount)
                #prune.remove(module, 'weight') # make it permanent

    return model
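
A hedged usage sketch of the pruner above (the toy model is an assumption): each Conv2d/Linear layer is pruned independently, so the requested amount is enforced per layer rather than globally.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8, 2))
model = pruner(model, amount=0.5, random=True)

prunable = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
zeros = sum(int((m.weight == 0).sum()) for m in prunable)
total = sum(m.weight.nelement() for m in prunable)
print(f"overall sparsity: {zeros / total:.2%}")  # roughly 50%
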
Code example #5
File: pruning.py  Project: jatinarora2409/darts
def prune_darts(model, pruning_percentage):
    for modules in model.children():
        if not isinstance(modules, nn.AdaptiveAvgPool3d):
            for module in modules:
                if not isinstance(module, Cell):
                    # print(list(module.named_parameters()))
                    prune.random_unstructured(module,
                                              name="weight",
                                              amount=pruning_percentage)
                    # print(list(module.named_parameters()))
                    # print("Not Cell")
                else:
                    # print(module)
                    for cell_module in module.children():
                        # print("Cell")
                        if isinstance(cell_module, ReLUConvBN):
                            # print("ReLUConvBN")
                            for ReLU_List in cell_module.children():
                                for ReLU_module in ReLU_List:
                                    if isinstance(ReLU_module, nn.Conv2d):
                                        # if ReLU_module.kernel_size[0] == 1:
                                        #     continue
                                        prune.l1_unstructured(
                                            ReLU_module,
                                            name="weight",
                                            amount=pruning_percentage)
                        elif isinstance(cell_module, nn.ModuleList):
                            # print("nn.ModuleList")
                            for innerModules in cell_module.children():
                                for seqItem in innerModules.children():
                                    for layer in seqItem:
                                        if isinstance(layer, nn.Conv2d):
                                            # if layer.kernel_size[0] == 1:
                                            #     continue
                                            prune.l1_unstructured(
                                                layer,
                                                name="weight",
                                                amount=pruning_percentage)
Code example #6
def prune_random_unstructured(net_creator, imagenet_path, batch_size):
    for i in range(1, 5):
        for idx in range(2, 18):
            net = net_creator()
            module = net.features[idx].conv
            amount = i / 10
            prune.random_unstructured(module[0][0],
                                      name='weight',
                                      amount=amount)
            prune.random_unstructured(module[1][0],
                                      name='weight',
                                      amount=amount)

            prune.remove(module[0][0], 'weight')
            prune.remove(module[1][0], 'weight')

            time0 = time.time()
            result = test_imagenet(net, imagenet_path, batch_size)
            speed = 1000 / (time.time() - time0)

            with open("log.txt", 'a+') as file:
                file.write(
                    "method: rand_unstr - module: {} - prune amount: {:.0%} - accuracy: {:.2f} - speed: {:.2f} \n"
                    .format(idx, amount, result, speed))
Code example #7
x_train, x_test, y_train, y_test = train_test_split(x,
                                                    y,
                                                    test_size=0.2,
                                                    random_state=42)

x_train = torch.FloatTensor(x_train)
y_train = torch.FloatTensor(blob_label(y_train, 0, [0]))
y_train = torch.FloatTensor(blob_label(y_train, 1, [1, 2, 3]))

x_test = torch.FloatTensor(x_test)
y_test = torch.FloatTensor(blob_label(y_test, 0, [0]))
y_test = torch.FloatTensor(blob_label(y_test, 1, [1, 2, 3]))

from torch.nn.utils import prune
"""
model = Feedforward(1024, 512)

# Prune weight
prune.random_unstructured(model.fc1, name="weight", amount=0.95)

criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

model.eval()
y_pred = model(x_test)
before_train = criterion(y_pred.squeeze(), y_test)
print('Test loss before training' , before_train.item())


model.train()
Code example #8
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

module = model.conv1
print(list(module.named_parameters()))

prune.random_unstructured(module, name="weight", amount=0.3)
print(list(module.named_parameters()))
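
Continuing the snippet above, a brief follow-up showing what pruning changed on module: the original tensor now lives in weight_orig, the mask sits in the weight_mask buffer, and module.weight is recomputed as weight_orig * weight_mask by a forward pre-hook; prune.remove makes the result permanent.

print(list(module.named_buffers()))     # now contains 'weight_mask'
print(module.weight)                    # masked view: about 30% of the entries are zero
print(module._forward_pre_hooks)        # the RandomUnstructured hook that recomputes 'weight'

prune.remove(module, "weight")          # bake the mask into a plain 'weight' parameter
print(list(module.named_parameters()))  # 'weight' is back, with the zeros kept
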
Code example #9
 def perform_pruning(self):
     prune.random_unstructured(module=self.fc1, name='weight', amount=0.2)
Code example #10
# NOTE:
#   prune.random_unstructured
#   prune.l1_unstructured
#   prune.random_structured
#   prune.ln_structured
# unstructured for weight pruning, structured for channel pruning

# Iterate over named_modules
for name, mod in pruned_net.named_modules():
    # if hasattr(mod, 'bias'):
    #     prune.random_unstructured(mod, name="bias", amount=0.5)
    # elif hasattr(mod, 'weight'):
    #     prune.random_unstructured(mod, name="weight", amount=0.5)
    # else:
    if isinstance(mod, torch.nn.Conv2d):
        prune.random_unstructured(mod, name="weight", amount=0.5)
    elif isinstance(mod, torch.nn.BatchNorm2d):
        prune.random_unstructured(mod, name="weight", amount=0.5)
        prune.random_unstructured(mod, name="bias", amount=0.5)
    elif isinstance(mod, torch.nn.Linear):
        prune.random_unstructured(mod, name="weight", amount=0.5)
        prune.random_unstructured(mod, name="bias", amount=0.5)


# Make the pruning permanent (strip the re-parametrization)
for name, mod in pruned_net.named_modules():
    if isinstance(mod, torch.nn.Linear):
        prune.remove(mod, name='weight')
        prune.remove(mod, name='bias')
    elif isinstance(mod, torch.nn.BatchNorm2d):
        prune.remove(mod, name='weight')
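
For the structured variants listed in the NOTE at the top of this example, a minimal sketch (layer shapes are arbitrary): ln_structured and random_structured zero out entire slices along dim, e.g. whole output channels, instead of individual weights.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 8, kernel_size=3)
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)  # drop the channels with the smallest L2 norm
lin = nn.Linear(16, 8)
prune.random_structured(lin, name="weight", amount=0.25, dim=0)   # drop random output rows

pruned_channels = (conv.weight.abs().sum(dim=(1, 2, 3)) == 0).sum().item()
print(f"pruned {pruned_channels} of {conv.out_channels} conv channels")  # 4 of 8
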
Code example #11
def prune_model(model, args, type):
    if type == 'train':
        if args.prune_train > 0:
            print('Pruning {} %'.format(args.prune_train * 100))
            if args.prune == 'global': print('Global Pruning')
            elif args.prune == 'l1': print('L1 Pruning')
            elif args.prune == 'random': print('Random Pruning')
            parameters_to_prune = []
            for mod_name, module in list(model.named_modules()):
                # for name, value in list(module.named_parameters()):
                if hasattr(module, 'weight') or hasattr(module, 'weight_mask'):
                    print(mod_name)
                    name = 'weight'
                    print('weights before {:.3f}%'.format(
                        float(torch.sum(module.weight == 0)) * 100 /
                        float(module.weight.nelement())))
                    if args.prune == 'global':
                        parameters_to_prune.append((module, name))
                    elif args.prune == 'l1':
                        prune.l1_unstructured(module,
                                              name=name,
                                              amount=args.prune_train)
                    elif args.prune == 'random':
                        prune.random_unstructured(module,
                                                  name=name,
                                                  amount=args.prune_train)
                    print('weights after {:.3f}%'.format(
                        float(torch.sum(module.weight == 0)) * 100 /
                        float(module.weight.nelement())))
                    # if prune.is_pruned(module):
                    #     prune.remove(module, 'weight')
                    #     print('removed',mod_name)
            if args.prune == 'global':
                prune.global_unstructured(parameters_to_prune,
                                          pruning_method=prune.L1Unstructured,
                                          amount=args.prune_train)
    elif type == 'eval':
        if args.prune_eval > 0:
            print('Pruning {} %'.format(args.prune_eval * 100))
            if args.prune == 'global': print('Global Pruning')
            elif args.prune == 'l1': print('L1 Pruning')
            elif args.prune == 'random': print('Random Pruning')
            parameters_to_prune = []
            for mod_name, module in list(model.named_modules()):
                # for name, value in list(module.named_parameters()):
                if hasattr(module, 'weight') or hasattr(module, 'weight_mask'):
                    print(mod_name)
                    name = 'weight'
                    print('weights before {:.3f}%'.format(
                        float(torch.sum(module.weight == 0)) * 100 /
                        float(module.weight.nelement())))
                    if args.prune == 'global':
                        parameters_to_prune.append((module, name))
                    elif args.prune == 'l1':
                        prune.l1_unstructured(module,
                                              name=name,
                                              amount=args.prune_eval)
                    elif args.prune == 'random':
                        prune.random_unstructured(module,
                                                  name=name,
                                                  amount=args.prune_eval)
                    print('weights after {:.3f}%'.format(
                        float(torch.sum(module.weight == 0)) * 100 /
                        float(module.weight.nelement())))
                    # if prune.is_pruned(module):
                    #     prune.remove(module, 'weight')
                    # print('removed',mod_name)
            if args.prune == 'global':
                prune.global_unstructured(parameters_to_prune,
                                          pruning_method=prune.L1Unstructured,
                                          amount=args.prune_eval)
    countZeroWeights(model)
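
A compact sketch of how the 'global' branch above differs from the per-layer branches (toy network assumed): global_unstructured ranks all collected weights together, so individual layers may end up more or less sparse than the requested amount, while l1_unstructured/random_unstructured enforce the amount on every layer separately.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
parameters_to_prune = [(net[0], "weight"), (net[2], "weight")]
prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.3)
for mod, _ in parameters_to_prune:
    sparsity = float((mod.weight == 0).sum()) / mod.weight.nelement()
    print(f"{mod}: {sparsity:.2%}")  # per-layer sparsity varies; the total is 30%
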
Code example #12
File: trainer.py  Project: zhaoliang1983x/libxsmm
    def train_step(self, samples, raise_oom=False):
        # prune accordingly
        def to_prune_or_not_to_prune():
            if self.get_num_updates() >= self.prune_start_step and \
                (self.get_num_updates() - self.prune_start_step) % self.pruning_interval == 0 and \
                abs(self.sparsity - self.target_sparsity) > 5e-3 and \
                self.sparsity < self.target_sparsity:
                return True
            return False

        def get_pruning_amount(train_step):
            n = self.num_pruning_steps
            dt = self.pruning_interval
            t_0 = self.prune_start_step
            s_i = 0.
            s_f = self.target_sparsity

            s_t = s_f + (s_i - s_f) * (1 - (train_step - t_0) / (n * dt))**3
            return s_t

        if to_prune_or_not_to_prune():
            # Determine how much to prune
            target_sparsity = get_pruning_amount(self.get_num_updates())

            pruning_amount = (target_sparsity -
                              self.sparsity) / (1. - self.sparsity)

            if pruning_amount > 0:
                for module, _ in self.get_modules_to_prune():
                    if self.prune_type == 'magnitude':
                        prune.l1_unstructured(module,
                                              name="weight",
                                              amount=pruning_amount)
                    else:
                        prune.random_unstructured(module,
                                                  name="weight",
                                                  amount=pruning_amount)
                """
                prune.global_unstructured(
                        self.get_modules_to_prune(),
                        pruning_method=prune.L1Unstructured if self.prune_type == 'magnitude' \
                                else prune.RandomUnstructured,
                                amount=pruning_amount)
                """
            logger.info(
                "NOTE: Weights pruned, type: {}, amount: {}, sparsity: {}".
                format(self.prune_type, pruning_amount, self.sparsity))
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():

                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch)

                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args.distributed_world_size == 1:
                        return None
                else:
                    raise e

            if self.tpu and i < len(samples) - 1:
                # tpu-comment: every XLA operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass without a backward pass
                import torch_xla.core.xla_model as xm
                xm.mark_step()

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            train_time = self._local_cumulative_training_time()
            logging_outputs, (
                sample_size, ooms,
                total_train_time) = self._aggregate_logging_outputs(
                    logging_outputs,
                    sample_size,
                    ooms,
                    train_time,
                    ignore=is_dummy_batch,
                )
            self._cumulative_training_time = total_train_time / self.data_parallel_world_size

        overflow = False
        try:
            if self.tpu and self.data_parallel_world_size > 1:
                import torch_xla.core.xla_model as xm
                gradients = xm._fetch_gradients(self.optimizer.optimizer)
                xm.all_reduce('sum',
                              gradients,
                              scale=1.0 / self.data_parallel_world_size)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                if not self.args.use_bmuf:
                    self.optimizer.multiply_grads(
                        self.data_parallel_world_size / sample_size)
                elif sample_size > 0:  # BMUF needs to check sample size
                    num = self.data_parallel_world_size if self._sync_stats(
                    ) else 1
                    self.optimizer.multiply_grads(num / sample_size)

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if (not self.args.use_bmuf
                    and self.args.distributed_wrapper != 'SlowMo'
                    and not self.tpu):
                self._check_grad_norms(grad_norm)

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            with NanDetector(self.model):
                self.task.train_step(sample,
                                     self.model,
                                     self.criterion,
                                     self.optimizer,
                                     self.get_num_updates(),
                                     ignore_grad=False)
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer)

        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

            if self.tpu:
                # mark step on TPUs
                import torch_xla.core.xla_model as xm
                xm.mark_step()

                # only log stats every log_interval steps
                # this causes wps to be misreported when log_interval > 1
                logging_output = {}
                if self.get_num_updates() % self.args.log_interval == 0:
                    logging_output = self._reduce_and_log_stats(
                        logging_outputs,
                        sample_size,
                        grad_norm,
                    )

                # log whenever there's an XLA compilation, since these
                # slow down training and may indicate opportunities for
                # optimization
                self._check_xla_compilation()
            else:
                # log stats
                logging_output = self._reduce_and_log_stats(
                    logging_outputs,
                    sample_size,
                    grad_norm,
                )

                # clear CUDA cache to reduce memory fragmentation
                if (self.cuda and self.args.empty_cache_freq > 0 and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1) %
                        self.args.empty_cache_freq) == 0):
                    torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        # Log pruning stuff
        metrics.log_scalar("sparsity", self.sparsity, priority=0, round=3)

        metrics.log_stop_time("train_wall")

        return logging_output
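
The helper get_pruning_amount above follows the cubic sparsity schedule of Zhu & Gupta ("To prune, or not to prune"): sparsity ramps from s_i to s_f over n pruning steps spaced dt updates apart. The amount passed to prune.l1_unstructured/random_unstructured is then (target - current) / (1 - current), because each call prunes a fraction of the weights that are still unpruned. A standalone sketch with illustrative parameter values:

def cubic_sparsity(step, t_0=1000, dt=100, n=10, s_i=0.0, s_f=0.9):
    # clamp progress to [0, 1] so the schedule is flat before t_0 and after t_0 + n * dt
    frac = min(max((step - t_0) / (n * dt), 0.0), 1.0)
    return s_f + (s_i - s_f) * (1.0 - frac) ** 3

current = 0.0
for step in (1000, 1250, 1500, 2000):
    target = cubic_sparsity(step)
    amount = (target - current) / (1.0 - current)  # fraction of the *remaining* weights to prune
    print(step, round(target, 3), round(amount, 3))
    current = target
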
Code example #13
# %%
# w2 = snim0.l2.weight.detach().cpu().numpy()
w2 = cmod.readout.weight.detach().cpu().numpy()

plt.imshow(w2)
# plt.plot(w2)

# plt.plot(np.sum(w2, axis=0))
plt.figure()
f=plt.plot(w2)

#%%
import torch.nn.utils.prune as prune

a = prune.random_unstructured(sgqm0.linear1, name="weight", amount=0.3)
sgqm1 = deepcopy(sgqm0)
w = a.weight.detach().cpu().numpy()

nfilt = w.shape[0]
sx,sy = U.get_subplot_dims(nfilt)
plt.figure(figsize=(10,10))
for cc in range(nfilt):
    plt.subplot(sx,sy,cc+1)
    wtmp = np.reshape(w[cc,:], (gd.num_lags, gd.NX*gd.NY))
    bestlag = np.argmax(np.std(wtmp, axis=1))
    # bestlag = 5
    plt.imshow(np.reshape(wtmp[bestlag,:], (gd.NY, gd.NX)))


# %% cnim plot
Code example #14
File: main_mp.py  Project: sanixa/GS-WGAN-custom
def main(args):
    ### config
    global noise_multiplier
    dataset = args.dataset
    num_discriminators = args.num_discriminators
    noise_multiplier = args.noise_multiplier
    z_dim = args.z_dim
    model_dim = args.model_dim
    batchsize = args.batchsize
    L_gp = args.L_gp
    L_epsilon = args.L_epsilon
    critic_iters = args.critic_iters
    latent_type = args.latent_type
    load_dir = args.load_dir
    save_dir = args.save_dir
    if_dp = (args.dp > 0.)
    gen_arch = args.gen_arch
    num_gpus = args.num_gpus

    ### CUDA
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device("cuda:%d" % i if use_cuda else "cpu")
        for i in range(num_gpus)
    ]
    device0 = devices[0]
    if use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    ### Random seed
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    ### Fix noise for visualization
    if latent_type == 'normal':
        fix_noise = torch.randn(10, z_dim)
    elif latent_type == 'bernoulli':
        p = 0.5
        bernoulli = torch.distributions.Bernoulli(torch.tensor([p]))
        fix_noise = bernoulli.sample((10, z_dim)).view(10, z_dim)
    else:
        raise NotImplementedError

    ### Set up models
    print('gen_arch:' + gen_arch)
    if dataset == 'mnist':
        netG = GeneratorDCGAN(z_dim=z_dim, model_dim=model_dim, num_classes=10)
    elif dataset == 'cifar_10':
        netG = GeneratorDCGAN_cifar(z_dim=z_dim,
                                    model_dim=model_dim,
                                    num_classes=10)

    netGS = copy.deepcopy(netG)

    ##prune
    if dataset == 'mnist':
        prune.random_unstructured(netG.fc, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv1, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv2, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv3, name="weight", amount=0.5)

        prune.random_unstructured(netGS.fc, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv1, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv2, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv3, name="weight", amount=0.5)
    elif dataset == 'cifar_10':
        prune.random_unstructured(netG.fc, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv1, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv2, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv3, name="weight", amount=0.5)
        prune.random_unstructured(netG.deconv4, name="weight", amount=0.5)

        prune.random_unstructured(netGS.fc, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv1, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv2, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv3, name="weight", amount=0.5)
        prune.random_unstructured(netGS.deconv4, name="weight", amount=0.5)

    netD_list = []
    for i in range(num_discriminators):
        if dataset == 'mnist':
            netD = DiscriminatorDCGAN()
        elif dataset == 'cifar_10':
            netD = DiscriminatorDCGAN_cifar()
        netD_list.append(netD)

    ### Load pre-trained discriminators
    print("load pre-training...")
    if load_dir is not None:
        for netD_id in range(num_discriminators):
            print('Load NetD ', str(netD_id))
            network_path = os.path.join(load_dir, 'netD_%d' % netD_id,
                                        'netD.pth')
            netD = netD_list[netD_id]
            netD.load_state_dict(torch.load(network_path))

    netG = netG.to(device0)
    for netD_id, netD in enumerate(netD_list):
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD.to(device)

    ### Set up optimizers
    optimizerD_list = []
    for i in range(num_discriminators):
        netD = netD_list[i]
        optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.99))
        optimizerD_list.append(optimizerD)
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.99))

    ### Data loaders
    if dataset == 'mnist' or dataset == 'fashionmnist':
        transform_train = transforms.Compose([
            transforms.CenterCrop((28, 28)),
            transforms.ToTensor(),
            #transforms.Grayscale(),
        ])
    elif dataset == 'cifar_100' or dataset == 'cifar_10':
        transform_train = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    if dataset == 'mnist':
        dataloader = datasets.MNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'MNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 784
        NUM_CLASSES = 10
    elif dataset == 'fashionmnist':
        dataloader = datasets.FashionMNIST
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'FashionMNIST'),
                              train=True,
                              download=True,
                              transform=transform_train)
    elif dataset == 'cifar_100':
        dataloader = datasets.CIFAR100
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR100'),
                              train=True,
                              download=True,
                              transform=transform_train)
        IMG_DIM = 3072
        NUM_CLASSES = 100
    elif dataset == 'cifar_10':
        IMG_DIM = 1024
        NUM_CLASSES = 10
        dataloader = datasets.CIFAR10
        trainset = dataloader(root=os.path.join(DATA_ROOT, 'CIFAR10'),
                              train=True,
                              download=True,
                              transform=transform_train)
    else:
        raise NotImplementedError

    ###fix sub-training set (fix to 10000 training samples)
    if args.update_train_dataset:
        indices_full = np.arange(len(trainset))
        np.random.shuffle(indices_full)
        indices_10000 = indices_full[:10000]
        np.savetxt('index_10000.txt', indices_10000, fmt='%i')
    indices = np.loadtxt('index_10000.txt', dtype=np.int_)
    trainset = torch.utils.data.Subset(trainset, indices)

    print('create indices file')
    indices_full = np.arange(len(trainset))
    np.random.shuffle(indices_full)
    #indices_full.dump(os.path.join(save_dir, 'indices.npy'))
    trainset_size = int(len(trainset) / num_discriminators)
    print('Size of the dataset: ', trainset_size)

    input_pipelines = []
    for i in range(num_discriminators):
        start = i * trainset_size
        end = (i + 1) * trainset_size
        indices = indices_full[start:end]
        trainloader = DataLoader(trainset,
                                 batch_size=args.batchsize,
                                 drop_last=False,
                                 num_workers=args.num_workers,
                                 sampler=SubsetRandomSampler(indices))
        #input_data = inf_train_gen(trainloader)
        input_pipelines.append(trainloader)

    if if_dp:
        ### Register hook
        global dynamic_hook_function
        for netD in netD_list:
            netD.conv1.register_backward_hook(master_hook_adder)

    prg_bar = tqdm(range(args.iterations + 1))
    for iters in prg_bar:
        #########################
        ### Update D network
        #########################
        netD_id = np.random.randint(num_discriminators, size=1)[0]
        device = devices[get_device_id(netD_id, num_discriminators, num_gpus)]
        netD = netD_list[netD_id]
        optimizerD = optimizerD_list[netD_id]
        input_data = input_pipelines[netD_id]

        for p in netD.parameters():
            p.requires_grad = True

        for iter_d in range(critic_iters):
            real_data, real_y = next(iter(input_data))
            real_data = real_data.view(-1, IMG_DIM)
            real_data = real_data.to(device)
            real_y = real_y.to(device)
            real_data_v = autograd.Variable(real_data)

            ### train with real
            dynamic_hook_function = dummy_hook
            netD.zero_grad()
            D_real_score = netD(real_data_v, real_y)
            D_real = -D_real_score.mean()

            ### train with fake
            batchsize = real_data.shape[0]
            if latent_type == 'normal':
                noise = torch.randn(batchsize, z_dim).to(device0)
            elif latent_type == 'bernoulli':
                noise = bernoulli.sample(
                    (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
            else:
                raise NotImplementedError
            noisev = autograd.Variable(noise)
            fake = autograd.Variable(netG(noisev, real_y.to(device0)).data)
            inputv = fake.to(device)
            D_fake = netD(inputv, real_y.to(device))
            D_fake = D_fake.mean()

            ### train with gradient penalty
            gradient_penalty = netD.calc_gradient_penalty(
                real_data_v.data, fake.data, real_y, L_gp, device)
            D_cost = D_fake + D_real + gradient_penalty

            ### train with epsilon penalty
            logit_cost = L_epsilon * torch.pow(D_real_score, 2).mean()
            D_cost += logit_cost

            ### update
            D_cost.backward()
            Wasserstein_D = -D_real - D_fake
            optimizerD.step()

        del real_data, real_y, fake, noise, inputv, D_real, D_fake, logit_cost, gradient_penalty
        torch.cuda.empty_cache()

        ############################
        # Update G network
        ###########################
        if if_dp:
            ### Sanitize the gradients passed to the Generator
            dynamic_hook_function = dp_conv_hook
        else:
            ### Only modify the gradient norm, without adding noise
            dynamic_hook_function = modify_gradnorm_conv_hook

        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()

        ### train with sanitized discriminator output
        if latent_type == 'normal':
            noise = torch.randn(batchsize, z_dim).to(device0)
        elif latent_type == 'bernoulli':
            noise = bernoulli.sample(
                (batchsize, z_dim)).view(batchsize, z_dim).to(device0)
        else:
            raise NotImplementedError
        label = torch.randint(0, NUM_CLASSES, [batchsize]).to(device0)
        noisev = autograd.Variable(noise)
        fake = netG(noisev, label)
        #summary(netG, input_data=[noisev,label])
        fake = fake.to(device)
        label = label.to(device)
        G = netD(fake, label)
        G = -G.mean()

        ### update
        G.backward()
        G_cost = G
        optimizerG.step()

        ### update the exponential moving average
        exp_mov_avg(netGS, netG, alpha=0.999, global_step=iters)

        ############################
        ### Results visualization
        ############################
        prg_bar.set_description(
            'iter:{}, G_cost:{:.2f}, D_cost:{:.2f}, Wasserstein:{:.2f}'.format(
                iters,
                G_cost.cpu().data,
                D_cost.cpu().data,
                Wasserstein_D.cpu().data))
        if iters % args.vis_step == 0:
            if dataset == 'mnist':
                generate_image_mnist(iters, netGS, fix_noise, save_dir,
                                     device0)
            elif dataset == 'cifar_100':
                generate_image_cifar100(iters, netGS, fix_noise, save_dir,
                                        device0)
            elif dataset == 'cifar_10':
                generate_image_cifar10(iters, netGS, fix_noise, save_dir,
                                       device0)

        if iters in [1000, 5000, 10000, 20000]:
            ### save model
            ##prune
            if dataset == 'mnist':
                prune.remove(netGS.fc, name="weight")
                prune.remove(netGS.deconv1, name="weight")
                prune.remove(netGS.deconv2, name="weight")
                prune.remove(netGS.deconv3, name="weight")

            elif dataset == 'cifar_10':

                prune.remove(netGS.fc, name="weight")
                prune.remove(netGS.deconv1, name="weight")
                prune.remove(netGS.deconv2, name="weight")
                prune.remove(netGS.deconv3, name="weight")
                prune.remove(netGS.deconv4, name="weight")

            ### save model
            torch.save(netGS.state_dict(),
                       os.path.join(save_dir, 'netGS_%d.pth' % iters))
            torch.save(netD.state_dict(),
                       os.path.join(save_dir, 'netD_%d.pth' % iters))

            ##prune
            if dataset == 'mnist':
                prune.random_unstructured(netGS.fc, name="weight", amount=0.5)
                prune.random_unstructured(netGS.deconv1,
                                          name="weight",
                                          amount=0.5)
                prune.random_unstructured(netGS.deconv2,
                                          name="weight",
                                          amount=0.5)
                prune.random_unstructured(netGS.deconv3,
                                          name="weight",
                                          amount=0.5)

            elif dataset == 'cifar_10':

                prune.random_unstructured(netGS.fc, name="weight", amount=0.5)
                prune.random_unstructured(netGS.deconv1,
                                          name="weight",
                                          amount=0.5)
                prune.random_unstructured(netGS.deconv2,
                                          name="weight",
                                          amount=0.5)
                prune.random_unstructured(netGS.deconv3,
                                          name="weight",
                                          amount=0.5)
                prune.random_unstructured(netGS.deconv4,
                                          name="weight",
                                          amount=0.5)

        del label, fake, noisev, noise, G, G_cost, D_cost
        torch.cuda.empty_cache()
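
In the checkpoint block above, prune.remove is called before torch.save so the saved state dict contains plain weight keys instead of weight_orig/weight_mask, and random_unstructured is re-applied right afterwards so training continues with masked weights (note that the new mask is drawn at random again). A minimal sketch of that round trip on a single layer, not the project's generator:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(8, 4)
prune.random_unstructured(layer, name="weight", amount=0.5)

prune.remove(layer, "weight")                                # bake the mask into 'weight'
torch.save(layer.state_dict(), "layer.pth")                  # checkpoint has plain 'weight'/'bias' keys
prune.random_unstructured(layer, name="weight", amount=0.5)  # resume with a fresh random mask
print(sorted(layer.state_dict().keys()))                     # ['bias', 'weight_mask', 'weight_orig']
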
Code example #15
    def train(
        self,
        method,
        x,
        y,
        x0=None,
        x1=None,
        x_val=None,
        y_val=None,
        alpha=1.0,
        optimizer="amsgrad",
        n_epochs=50,
        batch_size=128,
        initial_lr=0.001,
        final_lr=0.0001,
        nesterov_momentum=None,
        validation_split=0.25,
        early_stopping=True,
        scale_inputs=True,
        prune_network=False,
        limit_samplesize=None,
        memmap=False,
        verbose="some",
        scale_parameters=False,
        n_workers=8,
        clip_gradient=None,
        early_stopping_patience=None,
    ):

        """
        Trains the network.
        Parameters
        ----------
        method : str
            The inference method used for training. Allowed values are 'alice', 'alices', 'carl', 'cascal', 'rascal',
            and 'rolr'.
        x : ndarray or str
            Observations, or filename of a pickled numpy array.
        y : ndarray or str
            Class labels (0 = numerator, 1 = denominator), or filename of a pickled numpy array.
        alpha : float, optional
            Default value: 1.
        optimizer : {"adam", "amsgrad", "sgd"}, optional
            Optimization algorithm. Default value: "amsgrad".
        n_epochs : int, optional
            Number of epochs. Default value: 50.
        batch_size : int, optional
            Batch size. Default value: 128.
        initial_lr : float, optional
            Learning rate during the first epoch, after which it exponentially decays to final_lr. Default value:
            0.001.
        final_lr : float, optional
            Learning rate during the last epoch. Default value: 0.0001.
        nesterov_momentum : float or None, optional
            If trainer is "sgd", sets the Nesterov momentum. Default value: None.
        validation_split : float or None, optional
            Fraction of samples used  for validation and early stopping (if early_stopping is True). If None, the entire
            sample is used for training and early stopping is deactivated. Default value: 0.25.
        early_stopping : bool, optional
            Activates early stopping based on the validation loss (only if validation_split is not None). Default value:
            True.
        scale_inputs : bool, optional
            Scale the observables to zero mean and unit variance. Default value: True.
        memmap : bool, optional.
            If True, training files larger than 1 GB will not be loaded into memory at once. Default value: False.
        verbose : {"all", "many", "some", "few", "none"}, optional
            Determines verbosity of training. Default value: "some".
        Returns
        -------
            None
        """

        logger.info("Starting training")
        logger.info("  Method:                 %s", method)
        logger.info("  Batch size:             %s", batch_size)
        logger.info("  Optimizer:              %s", optimizer)
        logger.info("  Epochs:                 %s", n_epochs)
        logger.info("  Learning rate:          %s initially, decaying to %s", initial_lr, final_lr)
        if optimizer == "sgd":
            logger.info("  Nesterov momentum:      %s", nesterov_momentum)
        logger.info("  Validation split:       %s", validation_split)
        logger.info("  Early stopping:         %s", early_stopping)
        logger.info("  Scale inputs:           %s", scale_inputs)
        if limit_samplesize is None:
            logger.info("  Samples:                all")
        else:
            logger.info("  Samples:                %s", limit_samplesize)

        # Load training data
        logger.info("Loading training data")
        memmap_threshold = 1.0 if memmap else None
        x  = load_and_check(x, memmap_files_larger_than_gb=memmap_threshold)
        y  = load_and_check(y, memmap_files_larger_than_gb=memmap_threshold)
        x0 = load_and_check(x0, memmap_files_larger_than_gb=memmap_threshold)
        x1 = load_and_check(x1, memmap_files_larger_than_gb=memmap_threshold)

        # Infer dimensions of problem
        n_samples = x.shape[0]
        n_observables = x.shape[1]
        logger.info("Found %s samples with %s observables", n_samples, n_observables)

        external_validation = x_val is not None and y_val is not None
        if external_validation:
            x_val = load_and_check(x_val, memmap_files_larger_than_gb=memmap_threshold)
            y_val = load_and_check(y_val, memmap_files_larger_than_gb=memmap_threshold)
            logger.info("Found %s separate validation samples", x_val.shape[0])

            assert x_val.shape[1] == n_observables


        # Scale features
        if scale_inputs:
            self.initialize_input_transform(x, overwrite=False)
            x = self._transform_inputs(x)
            if external_validation:
                x_val = self._transform_inputs(x_val)
        else:
            self.initialize_input_transform(x, False, overwrite=False)

        # Features
        if self.features is not None:
            x = x[:, self.features]
            logger.info("Only using %s of %s observables", x.shape[1], n_observables)
            n_observables = x.shape[1]
            if external_validation:
                x_val = x_val[:, self.features]


        # Check consistency of input with model
        if self.n_observables is None:
            self.n_observables = n_observables

        if n_observables != self.n_observables:
            raise RuntimeError(
                "Number of observables does not match model: {} vs {}".format(n_observables, self.n_observables)
            )

        # Data
        data = self._package_training_data(method, x, y)
        if external_validation:
            data_val = self._package_training_data(method, x_val, y_val)
        else:
            data_val = None
        # Create model
        if self.model is None:
            logger.info("Creating model")
            self._create_model()
        if prune_network:
            module = self.model.ll1
            print("before pruning")
            print(list(module.named_parameters()))
            print(list(module.named_buffers()))

            prune.random_unstructured(module, name="weight", amount=0.3)
            print("before pruning")
            print(list(module.named_parameters()))
            print(list(module.named_buffers()))
            print(self.model.state_dict().keys())


        # Losses
        w = len(x0) / len(x1)
        logger.info("Passing weight %s to the loss function to account for the imbalanced dataset", w)
        loss_functions, loss_labels, loss_weights = get_loss(method + "2", alpha, w)
        # Optimizer
        opt, opt_kwargs = get_optimizer(optimizer, nesterov_momentum)

        # Train model
        logger.info("Training model")
        trainer = RatioTrainer(self.model, n_workers=n_workers)

        result = trainer.train(
            data=data,
            data_val=data_val,
            loss_functions=loss_functions,
            loss_weights=loss_weights,
            loss_labels=loss_labels,
            epochs=n_epochs,
            batch_size=batch_size,
            optimizer=opt,
            optimizer_kwargs=opt_kwargs,
            initial_lr=initial_lr,
            final_lr=final_lr,
            validation_split=validation_split,
            early_stopping=early_stopping,
            verbose=verbose,
            clip_gradient=clip_gradient,
            early_stopping_patience=early_stopping_patience,
        )
        return result