    def init_data_optim(self):
        self.params = []
        state = self.state
        optim_lr = state.lr

        # labels
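        # (hard labels: one fixed integer class index per distilled image,
        #  repeated identically for every data step; not optimized)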
        self.labels = []
        distill_label = torch.arange(state.num_classes, dtype=torch.long, device=state.device) \
                             .repeat(state.distilled_images_per_class_per_step, 1)  # [[0, 1, 2, ...], [0, 1, 2, ...]]
        distill_label = distill_label.t().reshape(-1)  # [0, 0, ..., 1, 1, ...]
        for _ in range(self.num_data_steps):
            self.labels.append(distill_label)
        self.all_labels = torch.cat(self.labels)

        # data
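        # distilled images start as standard-normal noise and are optimized directly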
        self.data = []
        for _ in range(self.num_data_steps):
            distill_data = torch.randn(self.num_per_step,
                                       state.nc,
                                       state.input_size,
                                       state.input_size,
                                       device=state.device,
                                       requires_grad=True)
            self.data.append(distill_data)
            self.params.append(distill_data)

        # lr

        # undo the softplus + threshold
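        # softplus(x) = log(1 + exp(x)), so log(expm1(y)) maps the configured
        # distill_lr back to its pre-softplus value; softplus is assumed to be
        # re-applied in the unrolled forward pass to keep the per-step lrs positive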
        raw_init_distill_lr = torch.tensor(state.distill_lr,
                                           device=state.device)
        raw_init_distill_lr = raw_init_distill_lr.repeat(self.T, 1)
        self.raw_distill_lrs = raw_init_distill_lr.expm1_().log_().requires_grad_()
        self.params.append(self.raw_distill_lrs)

        assert len(self.params) > 0, "must have at least 1 parameter"

        # now all the params are in self.params, sync if using distributed
        if state.distributed:
            broadcast_coalesced(self.params)
            logging.info("parameters broadcast done!")

        self.optimizer = optim.Adam(self.params,
                                    lr=optim_lr,
                                    betas=(0.5, 0.999))
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=state.decay_epochs,
            gamma=state.decay_factor)
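        # pre-allocate .grad buffers (assumed to let later steps accumulate or
        # all-reduce gradients into p.grad without None checks)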
        for p in self.params:
            p.grad = torch.zeros_like(p)
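
    # Extended variant: labels can be randomly initialized and/or learned as soft
    # labels (unless state.static_labels), distilled data can be frozen
    # (state.freeze_data), and text inputs (state.textdata) use an
    # embedding-sized last dimension.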
    def init_data_optim(self):
        self.params = []
        state = self.state
        optim_lr = state.lr
        req_lbl_grad = not state.static_labels  # labels are learnable unless static

        # labels
        self.labels = []

        for _ in range(self.num_data_steps):
            if state.random_init_labels:
                distill_label = distillation_label_initialiser(
                    state, self.num_per_step, torch.float, req_lbl_grad)
            else:
                # one-hot initialization from the configured init_labels
                # (a single-column label in the binary case)
                label_width = 1 if state.num_classes == 2 else state.num_classes
                dl_array = [[i == j for i in range(label_width)]
                            for j in state.init_labels
                            ] * state.distilled_images_per_class_per_step
                distill_label = torch.tensor(dl_array,
                                             dtype=torch.float,
                                             requires_grad=req_lbl_grad,
                                             device=state.device)

            self.labels.append(distill_label)
            if not state.static_labels:
                # soft labels are optimized alongside the distilled data
                self.params.append(distill_label)
        self.all_labels = torch.cat(self.labels)

        # data
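        # distilled inputs start as standard-normal noise; they are only
        # optimized when state.freeze_data is not set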
        self.data = []
        for _ in range(self.num_data_steps):
            # text inputs use the embedding size as the last dimension; images are square
            last_dim = state.ninp if state.textdata else state.input_size
            distill_data = torch.randn(self.num_per_step,
                                       state.nc,
                                       state.input_size,
                                       last_dim,
                                       device=state.device,
                                       requires_grad=(not state.freeze_data))
            # alternative: binary {0, 1} initialization
            # distill_data = torch.randint(2, (self.num_per_step, state.nc, state.input_size, state.input_size),
            #                              device=state.device, requires_grad=(not state.freeze_data), dtype=torch.float)
            self.data.append(distill_data)
            if not state.freeze_data:
                self.params.append(distill_data)

        # lr

        # undo the softplus + threshold
        raw_init_distill_lr = torch.tensor(state.distill_lr,
                                           device=state.device)
        raw_init_distill_lr = raw_init_distill_lr.repeat(self.T, 1)
        self.raw_distill_lrs = raw_init_distill_lr.expm1_().log_().requires_grad_()
        self.params.append(self.raw_distill_lrs)

        assert len(self.params) > 0, "must have at least 1 parameter"

        # now all the params are in self.params, sync if using distributed
        if state.distributed:
            broadcast_coalesced(self.params)
            logging.info("parameters broadcast done!")

        self.optimizer = optim.Adam(self.params,
                                    lr=optim_lr,
                                    betas=(0.5, 0.999))
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=state.decay_epochs,
            gamma=state.decay_factor)
        for p in self.params:
            p.grad = torch.zeros_like(p)