def train_func(self, data, iteration, pbar):
    images, labels = self.preprocess_image(data)
    images = images.tensor
    bs = len(images)

    # Sample one architecture per image from the controller
    batched_arcs = get_ddp_attr(self.controller, 'get_sampled_arc')(bs=bs)

    # Train the shared GAN weights on the sampled architectures
    self.gan_model(images=images, labels=labels, z=self.z_train,
                   iteration=iteration, batched_arcs=batched_arcs)

    # Periodically update the controller with the RL objective
    if iteration % self.train_controller_every_iter == 0:
        get_ddp_attr(self.controller, 'train_controller')(
            G=self.G, z=self.z_train, y=self.y_train,
            controller=self.controller, controller_optim=self.controller_optim,
            iteration=iteration, pbar=pbar)

    # Just for monitoring the training process
    sampled_arc = get_ddp_attr(self.controller, 'get_sampled_arc')()
    sampled_arc = self.get_tensor_of_main_processing(sampled_arc)

    classes_arcs = sampled_arc[[0, ], ].repeat(self.n_classes, 1)
    self.evaluate_model(classes_arcs=classes_arcs, iteration=iteration)
    comm.synchronize()
def get_fixed_arc(self, fixed_arc_file, fixed_epoch):
    self.logger.info(
        f'\n\tUsing fixed_arc: {fixed_arc_file}, \n\tfixed_epoch: {fixed_epoch}')

    if os.path.isfile(fixed_arc_file):
        n_classes = self.n_classes
        if fixed_epoch < 0:
            # Read a single block of n_classes arc rows (no epoch header)
            with open(fixed_arc_file) as f:
                sample_arc = []
                for _ in range(n_classes):
                    class_arc = f.readline().strip('[\n ]')
                    sample_arc.append(np.fromstring(class_arc, dtype=int, sep=' '))
        else:
            # Scan the file block by block until the epoch header matches fixed_epoch
            with open(fixed_arc_file) as f:
                while True:
                    epoch_str = f.readline().strip(': \n')
                    sample_arc = []
                    for _ in range(n_classes):
                        class_arc = f.readline().strip('[\n ]')
                        sample_arc.append(np.fromstring(class_arc, dtype=int, sep=' '))
                    if fixed_epoch == int(epoch_str):
                        break
        sample_arc = np.array(sample_arc)
    elif fixed_arc_file == 'random':
        # Randomly sample per-class arcs until no two classes share an identical arc
        while True:
            sample_arc = np.random.randint(
                0, len(get_ddp_attr(self.G, 'ops')),
                (self.n_classes, get_ddp_attr(self.G, 'num_layers')), dtype=int)
            if not self.is_repetitive_between_classes(sample_arc=sample_arc):
                break
    elif fixed_arc_file in get_ddp_attr(self.G, 'cfg_ops'):
        # single op: every layer of every class uses the same operation
        ops = list(get_ddp_attr(self.G, 'cfg_ops').keys())
        op_idx = ops.index(fixed_arc_file)
        sample_arc = np.ones(
            (self.n_classes, get_ddp_attr(self.G, 'num_layers')), dtype=int) * op_idx
    else:
        raise NotImplementedError

    self.logger.info(f'Sample_arc: \n{sample_arc}')
    self.revise_num_parameters(fixed_arc=sample_arc)

    fixed_arc = torch.from_numpy(sample_arc).cuda()
    return fixed_arc
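# Note: a sketch (inferred from the parsing above, not taken verbatim from the repo)
# of the arc-file layout get_fixed_arc() expects. With fixed_epoch >= 0, each block
# starts with an epoch header line ending in ':', followed by n_classes rows of
# space-separated op indices wrapped in brackets, e.g.
#
#   10:
#   [0 2 1 3 0 2]
#   [1 1 0 2 3 0]
#   ...
#
# With fixed_epoch < 0, only a single block of n_classes bracketed rows (no epoch
# header) is read from the top of the file.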
def train_func(self, data, iteration, pbar):
    # images, labels = self.preprocess_image(data)
    # images = images.tensor

    get_ddp_attr(self.controller, 'train_controller')(
        G=self.G, z=self.z_train, y=self.y_train,
        controller=self.controller, controller_optim=self.controller_optim,
        iteration=iteration, pbar=pbar)

    # Just for monitoring the training process
    # sync arcs
    sampled_arc = get_ddp_attr(self.controller, 'get_sampled_arc')()
    sampled_arc = self.get_tensor_of_main_processing(sampled_arc)

    classes_arcs = sampled_arc[[0, ], ].repeat(self.n_classes, 1)
    self.evaluate_model(classes_arcs=classes_arcs, iteration=iteration)
    comm.synchronize()
def __init__(self, cfg, **kwargs):
    super(TrainerDenseGANRetrain, self).__init__(cfg=cfg, **kwargs)

    num_layers = get_ddp_attr(self.G, 'num_layers')
    self.arcs = torch.zeros((1, num_layers), dtype=torch.int64).cuda()

    # Exponential moving average copy of the generator
    self.G_ema = copy.deepcopy(self.G)
    self.ema = EMA(source=self.G, target=self.G_ema, decay=0.999, start_itr=0)
    self.models.update({'G_ema': self.G_ema})
    pass
def __init__(self, cfg, **kwargs):
    super(TrainerDenseGANEvaluate, self).__init__(cfg=cfg, **kwargs)

    self.ckpt_dir = get_attr_kwargs(cfg.trainer, 'ckpt_dir', **kwargs)
    self.ckpt_epoch = get_attr_kwargs(cfg.trainer, 'ckpt_epoch', **kwargs)
    self.iter_every_epoch = get_attr_kwargs(cfg.trainer, 'iter_every_epoch', **kwargs)

    self.G_ema = copy.deepcopy(self.G)
    self.models.update({'G_ema': self.G_ema})

    # Restore the generator from the checkpoint of the requested epoch
    eval_ckpt = self._get_ckpt_path(ckpt_dir=self.ckpt_dir,
                                    ckpt_epoch=self.ckpt_epoch,
                                    iter_every_epoch=self.iter_every_epoch)
    self._load_G(eval_ckpt)

    num_layers = get_ddp_attr(self.G, 'num_layers')
    self.arcs = torch.zeros((1, num_layers), dtype=torch.int64).cuda()
    pass
def train_controller(self, searched_cnn, valid_dataset_iter, preprocess_image_func,
                     controller, controller_optim, iteration, pbar):
    """
    :param controller: for ddp training
    :return:
    """
    if comm.is_main_process() and iteration % 1000 == 0:
        pbar.set_postfix_str("ClsControllerRLAlphaFair")

    meter_dict = {}

    controller.train()
    controller.zero_grad()

    sampled_arcs = controller()
    sample_entropy = get_ddp_attr(controller, 'sample_entropy')
    sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

    val_data = next(valid_dataset_iter)
    bs = len(val_data)
    batched_arcs = sampled_arcs.repeat(bs, 1)

    # Reward: validation top-1 accuracy of the shared CNN under the sampled arcs
    top1 = AverageMeter()
    for i in range(self.num_aggregate):
        val_data = next(valid_dataset_iter)
        val_X, val_y = preprocess_image_func(val_data, device=self.device)
        val_X = val_X.tensor
        with torch.set_grad_enabled(False):
            logits = searched_cnn(val_X, batched_arcs=batched_arcs)
            prec1, = top_accuracy(output=logits, target=val_y, topk=(1,))
            top1.update(prec1.item(), bs)

    reward_g = top1.avg
    meter_dict['reward_g'] = reward_g

    # detach to make sure that gradients aren't backpropped through the reward
    reward = torch.tensor(reward_g).cuda()
    sample_entropy_mean = sample_entropy.mean()
    meter_dict['sample_entropy'] = sample_entropy_mean.item()
    reward += self.entropy_weight * sample_entropy_mean

    if self.baseline is None:
        baseline = torch.tensor(reward_g)
    else:
        baseline = self.baseline - (1 - self.bl_dec) * (self.baseline - reward)
        # detach to make sure that gradients are not backpropped through the baseline
        baseline = baseline.detach()

    sample_log_prob_mean = sample_log_prob.mean()
    meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
    loss = -1 * sample_log_prob_mean * (reward - baseline)

    meter_dict['reward'] = reward.item()
    meter_dict['baseline'] = baseline.item()
    meter_dict['loss'] = loss.item()

    loss.backward(retain_graph=False)

    grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                               self.child_grad_bound)
    meter_dict['grad_norm'] = grad_norm

    controller_optim.step()

    # Average the baseline across distributed workers
    baseline_list = comm.all_gather(baseline)
    baseline_mean = sum(map(lambda v: v.item(), baseline_list)) / len(baseline_list)
    baseline.fill_(baseline_mean)
    self.baseline = baseline

    if iteration % self.log_every_iter == 0:
        self.print_distribution(iteration=iteration, log_prob=False, print_interval=10)
        default_dicts = collections.defaultdict(dict)
        for meter_k, meter in meter_dict.items():
            if meter_k in ['reward', 'baseline']:
                default_dicts['reward_baseline'][meter_k] = meter
            else:
                default_dicts[meter_k][meter_k] = meter
        summary_defaultdict2txtfig(default_dict=default_dicts,
                                   prefix='train_controller',
                                   step=iteration,
                                   textlogger=self.myargs.textlogger)
    comm.synchronize()
    return
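# The controller update above is plain REINFORCE with an entropy bonus and a moving
# baseline:
#   R     = top-1 validation accuracy of the shared CNN (+ entropy_weight * H[pi])
#   b    <- b - (1 - bl_dec) * (b - R)          # exponential-moving-average baseline
#   loss  = -log pi(sampled_arcs) * (R - b)
# so architectures that beat the running baseline get their log-probability pushed up.
# A minimal numeric illustration (hypothetical values, not from the repo):
#   R = 0.72, b = 0.70, log_prob = -3.2  ->  loss = -(-3.2) * (0.72 - 0.70) = 0.064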
def train_controller(self, G, z, y, controller, controller_optim, iteration, pbar):
    """
    :param controller: for ddp training
    :return:
    """
    if comm.is_main_process() and iteration % 1000 == 0:
        pbar.set_postfix_str("ControllerRLAlpha")

    meter_dict = {}

    G.eval()
    controller.train()
    controller.zero_grad()

    sampled_arcs = controller(iteration)
    sample_entropy = get_ddp_attr(controller, 'sample_entropy')
    sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

    # Reward: Inception Score of samples generated with the per-class sampled arcs
    pool_list, logits_list = [], []
    for i in range(self.num_aggregate):
        z_samples = z.sample().to(self.device)
        y_samples = y.sample().to(self.device)
        with torch.set_grad_enabled(False):
            batched_arcs = sampled_arcs[y_samples]
            x = G(z=z_samples, y=y_samples, batched_arcs=batched_arcs)
        pool, logits = self.FID_IS.get_pool_and_logits(x)
        # pool_list.append(pool)
        logits_list.append(logits)

    # pool = np.concatenate(pool_list, 0)
    logits = np.concatenate(logits_list, 0)

    reward_g, _ = self.FID_IS.calculate_IS(logits)
    meter_dict['reward_g'] = reward_g

    # detach to make sure that gradients aren't backpropped through the reward
    reward = torch.tensor(reward_g).cuda()
    sample_entropy_mean = sample_entropy.mean()
    meter_dict['sample_entropy'] = sample_entropy_mean.item()
    reward += self.entropy_weight * sample_entropy_mean

    if self.baseline is None:
        baseline = torch.tensor(reward_g)
    else:
        baseline = self.baseline - (1 - self.bl_dec) * (self.baseline - reward)
        # detach to make sure that gradients are not backpropped through the baseline
        baseline = baseline.detach()

    sample_log_prob_mean = sample_log_prob.mean()
    meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
    loss = -1 * sample_log_prob_mean * (reward - baseline)

    meter_dict['reward'] = reward.item()
    meter_dict['baseline'] = baseline.item()
    meter_dict['loss'] = loss.item()

    loss.backward(retain_graph=False)

    grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                               self.child_grad_bound)
    meter_dict['grad_norm'] = grad_norm

    controller_optim.step()

    # Average the baseline across distributed workers
    baseline_list = comm.all_gather(baseline)
    baseline_mean = sum(map(lambda v: v.item(), baseline_list)) / len(baseline_list)
    baseline.fill_(baseline_mean)
    self.baseline = baseline

    if iteration % self.log_every_iter == 0:
        self.print_distribution(iteration=iteration, print_interval=10)
        if len(sampled_arcs) <= 10:
            self.logger.info('\nsampled arcs: \n%s' % sampled_arcs.cpu().numpy())
        self.myargs.textlogger.logstr(
            iteration,
            sampled_arcs='\n' + np.array2string(sampled_arcs.cpu().numpy(),
                                                threshold=np.inf))

        default_dicts = collections.defaultdict(dict)
        for meter_k, meter in meter_dict.items():
            if meter_k in ['reward', 'baseline']:
                default_dicts['reward_baseline'][meter_k] = meter
            else:
                default_dicts[meter_k][meter_k] = meter
        summary_defaultdict2txtfig(default_dict=default_dicts,
                                   prefix='train_controller',
                                   step=iteration,
                                   textlogger=self.myargs.textlogger)
    comm.synchronize()
    return
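# Shape note for the reward loop above (a sketch; exact shapes depend on the
# controller implementation): sampled_arcs is assumed to be (n_classes, num_layers),
# so indexing with y_samples of shape (batch,) yields batched_arcs of shape
# (batch, num_layers), i.e. each fake image is generated with the arc sampled for
# its own class. The Inception Score is then computed over the logits accumulated
# across all num_aggregate batches, and that single scalar is the RL reward.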