def train_func(self, data, iteration, pbar):
        """Run one iteration of the GAN architecture-search loop.

        Steps:
          1. Preprocess ``data`` and draw one architecture per batch sample
             from the controller's ``get_sampled_arc``.
          2. Call ``self.gan_model`` with the batch, ``self.z_train`` and those
             per-sample architectures.
          3. Every ``self.train_controller_every_iter`` iterations, call the
             controller's ``train_controller``.
          4. For monitoring only: sample an arc again, pass it through
             ``self.get_tensor_of_main_processing``, repeat its first row once
             per class and hand it to ``self.evaluate_model``.
          5. Synchronize all processes.

        :param data: raw batch accepted by ``self.preprocess_image``.
        :param iteration: global iteration counter.
        :param pbar: progress bar forwarded to the controller's training step.
        """
        image_list, labels = self.preprocess_image(data)
        images = image_list.tensor

        sample_arc_fn = get_ddp_attr(self.controller, 'get_sampled_arc')
        per_sample_arcs = sample_arc_fn(bs=len(images))

        self.gan_model(images=images,
                       labels=labels,
                       z=self.z_train,
                       iteration=iteration,
                       batched_arcs=per_sample_arcs)

        should_train_controller = (
            iteration % self.train_controller_every_iter == 0)
        if should_train_controller:
            controller_step = get_ddp_attr(self.controller, 'train_controller')
            controller_step(G=self.G,
                            z=self.z_train,
                            y=self.y_train,
                            controller=self.controller,
                            controller_optim=self.controller_optim,
                            iteration=iteration,
                            pbar=pbar)

        # Monitoring: evaluate a single architecture for every class.
        monitor_arc = get_ddp_attr(self.controller, 'get_sampled_arc')()
        monitor_arc = self.get_tensor_of_main_processing(monitor_arc)
        # Indexing with the list ``[0]`` keeps the leading dimension, so the
        # selected row can be tiled to one row per class.
        first_row = monitor_arc[[0]]
        self.evaluate_model(classes_arcs=first_row.repeat(self.n_classes, 1),
                            iteration=iteration)
        comm.synchronize()
    def get_fixed_arc(self, fixed_arc_file, fixed_epoch):
        """Build the fixed per-class architecture used for retraining.

        ``fixed_arc_file`` selects where the architecture comes from:

        * An existing file path: each arc is one line of space-separated op
          indices, optionally wrapped in ``[`` / ``]``. With
          ``fixed_epoch < 0`` the first ``self.n_classes`` lines are used.
          Otherwise the file is read as consecutive records, each an
          ``<epoch>:`` header line followed by ``self.n_classes`` arc lines,
          and the record whose epoch equals ``fixed_epoch`` is used.
        * ``'random'``: op indices are drawn with ``np.random.randint`` over
          ``range(len(G.ops))`` into a ``(n_classes, num_layers)`` array,
          re-drawing until ``self.is_repetitive_between_classes`` is false.
        * A key of ``G.cfg_ops``: every layer of every class uses that op's
          index (its position among the ``cfg_ops`` keys).

        The chosen array is logged, passed to ``self.revise_num_parameters``
        and returned as a CUDA tensor.

        :param fixed_arc_file: file path, ``'random'``, or an op name.
        :param fixed_epoch: epoch record to select from the file; a negative
            value selects the first block of lines.
        :return: CUDA tensor with one row of op indices per class.
        :raises ValueError: if ``fixed_epoch >= 0`` and no record with that
            epoch exists before end of file.
        :raises NotImplementedError: for any other ``fixed_arc_file`` value.
        """
        self.logger.info(
            f'\n\tUsing fixed_arc: {fixed_arc_file}, \n\tfixed_epoch: {fixed_epoch}'
        )

        n_classes = self.n_classes

        def _read_class_arcs(f):
            # Read one arc line per class, e.g. "[0 3 1 2]" -> array([0, 3, 1, 2]).
            arcs = []
            for _ in range(n_classes):
                tokens = f.readline().strip('[\n ]').split()
                arcs.append(np.array([int(tok) for tok in tokens], dtype=int))
            return arcs

        if os.path.isfile(fixed_arc_file):
            with open(fixed_arc_file) as f:
                if fixed_epoch < 0:
                    class_arcs = _read_class_arcs(f)
                else:
                    while True:
                        header = f.readline()
                        if not header:
                            # EOF: the requested epoch record does not exist.
                            raise ValueError(
                                f'fixed_epoch {fixed_epoch} not found in '
                                f'{fixed_arc_file}')
                        class_arcs = _read_class_arcs(f)
                        if int(header.strip(': \n')) == fixed_epoch:
                            break
            sample_arc = np.array(class_arcs)

        elif fixed_arc_file == 'random':
            num_ops = len(get_ddp_attr(self.G, 'ops'))
            num_layers = get_ddp_attr(self.G, 'num_layers')
            # Re-draw until no two classes share the same architecture
            # (as judged by is_repetitive_between_classes).
            while True:
                sample_arc = np.random.randint(
                    0, num_ops, (n_classes, num_layers), dtype=int)
                if not self.is_repetitive_between_classes(
                        sample_arc=sample_arc):
                    break

        elif fixed_arc_file in get_ddp_attr(self.G, 'cfg_ops'):
            # A single op name: use that op for every layer of every class.
            ops = list(get_ddp_attr(self.G, 'cfg_ops').keys())
            op_idx = ops.index(fixed_arc_file)
            num_layers = get_ddp_attr(self.G, 'num_layers')
            sample_arc = np.full((n_classes, num_layers), op_idx, dtype=int)

        else:
            # ``raise NotImplemented`` would raise a TypeError, since
            # NotImplemented is a constant, not an exception class.
            raise NotImplementedError(
                f'Unsupported fixed_arc_file: {fixed_arc_file!r}; expected an '
                f"existing file, 'random', or an op name from cfg_ops.")

        self.logger.info(f'Sample_arc: \n{sample_arc}')
        self.revise_num_parameters(fixed_arc=sample_arc)
        fixed_arc = torch.from_numpy(sample_arc).cuda()
        return fixed_arc
Пример #3
0
    def train_func(self, data, iteration, pbar):
        """Train the controller once, then evaluate its current architecture.

        No generator step is taken here and ``data`` is not used. The
        controller's ``train_controller`` runs on every call. Afterwards a
        freshly sampled arc is passed through
        ``self.get_tensor_of_main_processing``, its first row is repeated once
        per class and evaluated, and all processes are synchronized.

        :param data: unused; kept for interface compatibility.
        :param iteration: global iteration counter.
        :param pbar: progress bar forwarded to ``train_controller``.
        """
        controller_step = get_ddp_attr(self.controller, 'train_controller')
        controller_step(G=self.G,
                        z=self.z_train,
                        y=self.y_train,
                        controller=self.controller,
                        controller_optim=self.controller_optim,
                        iteration=iteration,
                        pbar=pbar)

        # Monitoring only: use the same arc (as provided by
        # get_tensor_of_main_processing) for every class.
        arc = self.get_tensor_of_main_processing(
            get_ddp_attr(self.controller, 'get_sampled_arc')())
        # ``[[0]]`` keeps the leading dimension so the row can be tiled.
        self.evaluate_model(classes_arcs=arc[[0]].repeat(self.n_classes, 1),
                            iteration=iteration)
        comm.synchronize()
    def __init__(self, cfg, **kwargs):
        """Set up the retraining trainer.

        After the base-class setup this allocates a ``(1, num_layers)`` int64
        zero architecture on CUDA and creates ``G_ema``, a deep copy of the
        generator tracked by an ``EMA`` helper (decay 0.999, ``start_itr=0``),
        which is added to ``self.models`` under the key ``'G_ema'``.
        """
        super(TrainerDenseGANRetrain, self).__init__(cfg=cfg, **kwargs)

        layer_count = get_ddp_attr(self.G, 'num_layers')
        self.arcs = torch.zeros((1, layer_count), dtype=torch.int64).cuda()

        # Shadow copy of the generator, updated through the EMA helper.
        self.G_ema = copy.deepcopy(self.G)
        self.ema = EMA(source=self.G, target=self.G_ema,
                       decay=0.999, start_itr=0)
        self.models.update({'G_ema': self.G_ema})
    def __init__(self, cfg, **kwargs):
        """Set up the evaluation trainer and load a generator checkpoint.

        ``ckpt_dir``, ``ckpt_epoch`` and ``iter_every_epoch`` are looked up
        with ``get_attr_kwargs(cfg.trainer, <name>, **kwargs)`` and stored on
        the instance. A deep copy ``G_ema`` of the generator is created and
        added to ``self.models``. The checkpoint path is then resolved with
        ``self._get_ckpt_path`` and loaded with ``self._load_G``. Finally a
        ``(1, num_layers)`` int64 zero architecture is allocated on CUDA.
        """
        super(TrainerDenseGANEvaluate, self).__init__(cfg=cfg, **kwargs)

        trainer_cfg = cfg.trainer
        self.ckpt_dir = get_attr_kwargs(trainer_cfg, 'ckpt_dir', **kwargs)
        self.ckpt_epoch = get_attr_kwargs(trainer_cfg, 'ckpt_epoch', **kwargs)
        self.iter_every_epoch = get_attr_kwargs(trainer_cfg,
                                                'iter_every_epoch', **kwargs)

        # G_ema is registered before the checkpoint is loaded; presumably so
        # _load_G can populate it too -- TODO confirm.
        self.G_ema = copy.deepcopy(self.G)
        self.models.update({'G_ema': self.G_ema})

        ckpt_path = self._get_ckpt_path(ckpt_dir=self.ckpt_dir,
                                        ckpt_epoch=self.ckpt_epoch,
                                        iter_every_epoch=self.iter_every_epoch)
        self._load_G(ckpt_path)

        layer_count = get_ddp_attr(self.G, 'num_layers')
        self.arcs = torch.zeros((1, layer_count), dtype=torch.int64).cuda()
Пример #6
0
    def train_controller(self, searched_cnn, valid_dataset_iter,
                         preprocess_image_func, controller, controller_optim,
                         iteration, pbar):
        """Run one REINFORCE update of the controller using validation accuracy.

        The controller samples an architecture. Its reward is the mean top-1
        accuracy of ``searched_cnn`` (evaluated without gradients) over
        ``self.num_aggregate`` batches from ``valid_dataset_iter``, plus
        ``self.entropy_weight`` times the mean sample entropy. The loss
        ``-mean(log_prob) * (reward - baseline)`` is backpropagated, gradients
        are clipped to ``self.child_grad_bound`` and ``controller_optim`` takes
        a step. The moving-average baseline is then replaced by the mean over
        the values gathered with ``comm.all_gather``.

        :param searched_cnn: supernet, called as ``searched_cnn(x, batched_arcs=...)``.
        :param valid_dataset_iter: iterator yielding raw validation batches.
        :param preprocess_image_func: maps a raw batch to ``(images, labels)``,
            where ``images.tensor`` is the batched input.
        :param controller: controller module (for DDP training).
        :param controller_optim: optimizer for the controller parameters.
        :param iteration: global iteration counter (progress/log cadence).
        :param pbar: progress bar whose postfix is set on the main process.
        """
        if comm.is_main_process() and iteration % 1000 == 0:
            pbar.set_postfix_str("ClsControllerRLAlphaFair")

        meter_dict = {}

        controller.train()
        controller.zero_grad()

        sampled_arcs = controller()
        sample_entropy = get_ddp_attr(controller, 'sample_entropy')
        sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

        # Reward: mean top-1 accuracy of the sampled arc over num_aggregate
        # validation batches. The arc is tiled to each batch's own size, so
        # every fetched batch is evaluated and a smaller final batch still
        # lines up with its arcs.
        top1 = AverageMeter()
        for _ in range(self.num_aggregate):
            val_data = next(valid_dataset_iter)
            val_X, val_y = preprocess_image_func(val_data, device=self.device)
            val_X = val_X.tensor
            bs = len(val_X)
            batched_arcs = sampled_arcs.repeat(bs, 1)
            with torch.set_grad_enabled(False):
                logits = searched_cnn(val_X, batched_arcs=batched_arcs)
                prec1, = top_accuracy(output=logits, target=val_y, topk=(1, ))
                top1.update(prec1.item(), bs)

        reward_g = top1.avg
        meter_dict['reward_g'] = reward_g

        # torch.tensor() builds a fresh tensor with no autograd history, so the
        # accuracy part of the reward contributes no gradient.
        reward = torch.tensor(reward_g).cuda()
        sample_entropy_mean = sample_entropy.mean()
        meter_dict['sample_entropy'] = sample_entropy_mean.item()
        # NOTE(review): if sample_entropy carries gradients, this in-place add
        # makes ``reward`` (and hence the loss) depend on them as well.
        reward += self.entropy_weight * sample_entropy_mean

        # Exponential moving-average baseline:
        # b <- bl_dec * b + (1 - bl_dec) * reward. It is initialised to reward_g,
        # i.e. accuracy without the entropy bonus.
        if self.baseline is None:
            baseline = torch.tensor(reward_g)
        else:
            baseline = self.baseline - (1 - self.bl_dec) * (self.baseline -
                                                            reward)
            # Detach so that no gradient flows through the baseline.
            baseline = baseline.detach()

        sample_log_prob_mean = sample_log_prob.mean()
        meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
        loss = -1 * sample_log_prob_mean * (reward - baseline)

        meter_dict['reward'] = reward.item()
        meter_dict['baseline'] = baseline.item()
        meter_dict['loss'] = loss.item()

        loss.backward(retain_graph=False)

        grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                                   self.child_grad_bound)
        # clip_grad_norm_ returns a tensor in recent PyTorch; store a plain
        # float like every other meter.
        meter_dict['grad_norm'] = float(grad_norm)

        controller_optim.step()

        # Replace the baseline by the mean of the gathered baselines
        # (presumably one per process) so all processes share it.
        baseline_list = comm.all_gather(baseline)
        baseline_mean = sum(v.item() for v in baseline_list) / len(baseline_list)
        baseline.fill_(baseline_mean)
        self.baseline = baseline

        if iteration % self.log_every_iter == 0:
            self.print_distribution(iteration=iteration,
                                    log_prob=False,
                                    print_interval=10)
            # Group reward and baseline into one figure; every other meter
            # gets its own.
            default_dicts = collections.defaultdict(dict)
            for meter_k, meter in meter_dict.items():
                if meter_k in ['reward', 'baseline']:
                    default_dicts['reward_baseline'][meter_k] = meter
                else:
                    default_dicts[meter_k][meter_k] = meter
            summary_defaultdict2txtfig(default_dict=default_dicts,
                                       prefix='train_controller',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)
        comm.synchronize()
Пример #7
0
    def train_controller(self, G, z, y, controller, controller_optim,
                         iteration, pbar):
        """Run one REINFORCE update of the controller with an Inception-Score reward.

        ``controller(iteration)`` returns ``sampled_arcs``, which is indexed by
        the sampled class labels (presumably one arc row per class). For each
        of ``self.num_aggregate`` rounds:

        * ``z`` and ``y`` are sampled.
        * ``G`` generates images without gradients, using each sample's class
          arc.
        * The logits returned by ``self.FID_IS.get_pool_and_logits`` are
          collected.

        The reward is the first value returned by ``self.FID_IS.calculate_IS``
        on all logits, plus ``self.entropy_weight`` times the mean sample
        entropy. The loss ``-mean(log_prob) * (reward - baseline)`` is
        backpropagated, gradients are clipped to ``self.child_grad_bound`` and
        ``controller_optim`` takes a step. The moving-average baseline is then
        replaced by the mean over the values gathered with ``comm.all_gather``.

        :param G: generator, called as ``G(z=..., y=..., batched_arcs=...)``.
        :param z: latent sampler exposing ``sample()``.
        :param y: label sampler exposing ``sample()``.
        :param controller: controller module (for DDP training).
        :param controller_optim: optimizer for the controller parameters.
        :param iteration: global iteration counter (passed to the controller;
            also sets progress/log cadence).
        :param pbar: progress bar whose postfix is set on the main process.
        """
        if comm.is_main_process() and iteration % 1000 == 0:
            pbar.set_postfix_str("ControllerRLAlpha")

        meter_dict = {}

        # NOTE(review): G is switched to eval mode and not switched back here;
        # confirm the caller restores train mode where required.
        G.eval()
        controller.train()

        controller.zero_grad()

        sampled_arcs = controller(iteration)

        sample_entropy = get_ddp_attr(controller, 'sample_entropy')
        sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

        logits_list = []
        for _ in range(self.num_aggregate):
            z_samples = z.sample().to(self.device)
            y_samples = y.sample().to(self.device)
            with torch.set_grad_enabled(False):
                # Each generated sample uses the arc of its own class label.
                batched_arcs = sampled_arcs[y_samples]
                x = G(z=z_samples, y=y_samples, batched_arcs=batched_arcs)

            # Only the logits feed the reward; the pool features are unused.
            _, logits = self.FID_IS.get_pool_and_logits(x)
            logits_list.append(logits)

        logits = np.concatenate(logits_list, 0)

        reward_g, _ = self.FID_IS.calculate_IS(logits)
        meter_dict['reward_g'] = reward_g

        # torch.tensor() builds a fresh tensor with no autograd history, so the
        # score part of the reward contributes no gradient.
        reward = torch.tensor(reward_g).cuda()
        sample_entropy_mean = sample_entropy.mean()
        meter_dict['sample_entropy'] = sample_entropy_mean.item()
        # NOTE(review): if sample_entropy carries gradients, this in-place add
        # makes ``reward`` (and hence the loss) depend on them as well.
        reward += self.entropy_weight * sample_entropy_mean

        # Exponential moving-average baseline:
        # b <- bl_dec * b + (1 - bl_dec) * reward. It is initialised to reward_g,
        # i.e. the score without the entropy bonus.
        if self.baseline is None:
            baseline = torch.tensor(reward_g)
        else:
            baseline = self.baseline - (1 - self.bl_dec) * (self.baseline -
                                                            reward)
            # Detach so that no gradient flows through the baseline.
            baseline = baseline.detach()

        sample_log_prob_mean = sample_log_prob.mean()
        meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
        loss = -1 * sample_log_prob_mean * (reward - baseline)

        meter_dict['reward'] = reward.item()
        meter_dict['baseline'] = baseline.item()
        meter_dict['loss'] = loss.item()

        loss.backward(retain_graph=False)

        grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                                   self.child_grad_bound)
        # clip_grad_norm_ returns a tensor in recent PyTorch; store a plain
        # float like every other meter.
        meter_dict['grad_norm'] = float(grad_norm)

        controller_optim.step()

        # Replace the baseline by the mean of the gathered baselines
        # (presumably one per process) so all processes share it.
        baseline_list = comm.all_gather(baseline)
        baseline_mean = sum(v.item() for v in baseline_list) / len(baseline_list)
        baseline.fill_(baseline_mean)
        self.baseline = baseline

        if iteration % self.log_every_iter == 0:
            self.print_distribution(iteration=iteration, print_interval=10)
            sampled_arcs_np = sampled_arcs.cpu().numpy()
            if len(sampled_arcs) <= 10:
                self.logger.info('\nsampled arcs: \n%s' % sampled_arcs_np)
            self.myargs.textlogger.logstr(
                iteration,
                sampled_arcs='\n' +
                np.array2string(sampled_arcs_np, threshold=np.inf))
            # Group reward and baseline into one figure; every other meter
            # gets its own.
            default_dicts = collections.defaultdict(dict)
            for meter_k, meter in meter_dict.items():
                if meter_k in ['reward', 'baseline']:
                    default_dicts['reward_baseline'][meter_k] = meter
                else:
                    default_dicts[meter_k][meter_k] = meter
            summary_defaultdict2txtfig(default_dict=default_dicts,
                                       prefix='train_controller',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)
        comm.synchronize()