Example #1
    def print_distribution(self,
                           iteration,
                           log_prob=True,
                           print_interval=float('inf')):
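        """Log each layer's op-probability distribution and the argmax
        (searched) architecture; handler formatters are temporarily reduced
        to message-only output while printing.
        """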
        # remove formats
        org_formatters = []
        for handler in self.logger.handlers:
            org_formatters.append(handler.formatter)
            handler.setFormatter(logging.Formatter("%(message)s"))

        default_dict = collections.defaultdict(dict)
        self.logger.info("####### distribution #######")
        searched_arc = []
        for layer_id, op_dist in enumerate(self.op_dist):
            prob = op_dist.probs
            max_op_id = prob.argmax().item()
            searched_arc.append(max_op_id)
            for op_id, op_name in enumerate(self.cfg_ops.keys()):
                op_prob = prob[0][op_id]
                default_dict[f'L{layer_id}'][get_prefix_abb(
                    op_name)] = op_prob.item()

            if log_prob:
                if layer_id % print_interval == 0:
                    self.logger.info(layer_id // print_interval)
                self.logger.info(prob)

        searched_arc = np.array(searched_arc)
        self.logger.info('\nsearched arcs: \n%s' % searched_arc)
        self.myargs.textlogger.logstr(
            iteration,
            searched_arc='\n' +
            np.array2string(searched_arc, threshold=np.inf))

        summary_defaultdict2txtfig(default_dict=default_dict,
                                   prefix='',
                                   step=iteration,
                                   textlogger=self.myargs.textlogger)
        self.logger.info("#####################")

        # restore formats
        for handler, formatter in zip(self.logger.handlers, org_formatters):
            handler.setFormatter(formatter)
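
Example #1 temporarily swaps every handler's formatter for a bare "%(message)s" one and restores the originals at the end. A minimal sketch of the same save/restore pattern factored into a reusable context manager; plain_format and the demo logger are illustrative, not part of the original code:

import contextlib
import logging

@contextlib.contextmanager
def plain_format(logger):
    # swap every handler's formatter for a message-only one
    originals = [h.formatter for h in logger.handlers]
    for h in logger.handlers:
        h.setFormatter(logging.Formatter("%(message)s"))
    try:
        yield logger
    finally:
        # restore the original formatters even if logging raised
        for h, fmt in zip(logger.handlers, originals):
            h.setFormatter(fmt)

logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
with plain_format(logger):
    logger.info("####### distribution #######")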
Example #2
    def print_distribution(self, iteration, print_interval=float('inf')):
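        """Per-class variant of the distribution printer: logs one searched
        arc per class; op_dist.probs is indexed as probs[class_idx][op_id].
        """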
        # remove formats
        org_formatters = []
        for handler in self.logger.handlers:
            org_formatters.append(handler.formatter)
            handler.setFormatter(logging.Formatter("%(message)s"))

        self.logger.info("####### distribution #######")
        class_arcs = []
        for class_idx in range(self.n_classes):
            default_dict = collections.defaultdict(dict)
            searched_arc = []
            for layer_id, op_dist in enumerate(self.op_dist):
                prob = op_dist.probs
                max_op_id = prob[class_idx].argmax().item()
                searched_arc.append(max_op_id)
                for op_id, op_name in enumerate(self.cfg_ops.keys()):
                    op_prob = prob[class_idx][op_id]
                    default_dict[f'C{class_idx}L{layer_id}'][get_prefix_abb(
                        op_name)] = op_prob.item()

            class_arcs.append(searched_arc)
            searched_arc = np.array(searched_arc)
            self.logger.info(
                f'Class {class_idx} searched arcs: {searched_arc}')
            # self.myargs.textlogger.logstr(iteration,
            #                               searched_arc='\n' + np.array2string(searched_arc, threshold=np.inf))

            summary_defaultdict2txtfig(default_dict=default_dict,
                                       prefix='',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)

        class_arcs = np.array(class_arcs)
        self.myargs.textlogger.logstr(
            iteration,
            searched_class_arc='\n' +
            np.array2string(class_arcs, threshold=np.inf))
        self.logger.info("#####################")
        # restore formats
        for handler, formatter in zip(self.logger.handlers, org_formatters):
            handler.setFormatter(formatter)
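
Relative to Example #1, prob is indexed per class, which assumes op_dist.probs has shape (n_classes, n_ops). A toy numpy sketch of how class_arcs ends up as an (n_classes, n_layers) matrix; all sizes and names here are made up for illustration:

import numpy as np

n_classes, n_layers, n_ops = 3, 4, 5
rng = np.random.default_rng(0)
# one (n_classes, n_ops) row-stochastic matrix per layer
probs = rng.dirichlet(np.ones(n_ops), size=(n_layers, n_classes))

class_arcs = np.stack(
    [probs[:, c].argmax(axis=-1) for c in range(n_classes)])
print(class_arcs.shape)  # (3, 4): one op index per (class, layer)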
Example #3
    def train_controller(self, searched_cnn, valid_dataset_iter,
                         preprocess_image_func, controller, controller_optim,
                         iteration, pbar):
        """

    :param controller: for ddp training
    :return:
    """
        if comm.is_main_process() and iteration % 1000 == 0:
            pbar.set_postfix_str("ClsControllerRLAlphaFair")

        meter_dict = {}

        controller.train()
        controller.zero_grad()

        sampled_arcs = controller()
        sample_entropy = get_ddp_attr(controller, 'sample_entropy')
        sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

        val_data = next(valid_dataset_iter)
        bs = len(val_data)
        batched_arcs = sampled_arcs.repeat(bs, 1)

        top1 = AverageMeter()
        for i in range(self.num_aggregate):
            val_data = next(valid_dataset_iter)
            val_X, val_y = preprocess_image_func(val_data, device=self.device)
            val_X = val_X.tensor
            with torch.set_grad_enabled(False):
                logits = searched_cnn(val_X, batched_arcs=batched_arcs)
                prec1, = top_accuracy(output=logits, target=val_y, topk=(1, ))
                top1.update(prec1.item(), bs)

        reward_g = top1.avg
        meter_dict['reward_g'] = reward_g

        # detach to make sure that gradients aren't backpropped through the reward
        reward = torch.tensor(reward_g).cuda()
        sample_entropy_mean = sample_entropy.mean()
        meter_dict['sample_entropy'] = sample_entropy_mean.item()
        reward += self.entropy_weight * sample_entropy_mean

        if self.baseline is None:
            baseline = torch.tensor(reward_g)
        else:
            baseline = self.baseline - (1 - self.bl_dec) * (self.baseline -
                                                            reward)
            # detach to make sure that gradients are not backpropped through the baseline
            baseline = baseline.detach()

        sample_log_prob_mean = sample_log_prob.mean()
        meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
        loss = -1 * sample_log_prob_mean * (reward - baseline)

        meter_dict['reward'] = reward.item()
        meter_dict['baseline'] = baseline.item()
        meter_dict['loss'] = loss.item()

        loss.backward(retain_graph=False)

        grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                                   self.child_grad_bound)
        # clip_grad_norm_ may return a 0-dim tensor; log a plain float
        meter_dict['grad_norm'] = float(grad_norm)

        controller_optim.step()

        baseline_list = comm.all_gather(baseline)
        baseline_mean = sum(map(lambda v: v.item(),
                                baseline_list)) / len(baseline_list)
        baseline.fill_(baseline_mean)
        self.baseline = baseline

        if iteration % self.log_every_iter == 0:
            self.print_distribution(iteration=iteration,
                                    log_prob=False,
                                    print_interval=10)
            default_dicts = collections.defaultdict(dict)
            for meter_k, meter in meter_dict.items():
                if meter_k in ['reward', 'baseline']:
                    default_dicts['reward_baseline'][meter_k] = meter
                else:
                    default_dicts[meter_k][meter_k] = meter
            summary_defaultdict2txtfig(default_dict=default_dicts,
                                       prefix='train_controller',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)
        comm.synchronize()
        return
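
The baseline update above is an exponential moving average in disguise: b - (1 - bl_dec) * (b - r) equals bl_dec * b + (1 - bl_dec) * r. A standalone sketch checking the equivalence and the resulting REINFORCE loss; the bl_dec value and the numbers are illustrative:

import torch

bl_dec = 0.99
baseline = torch.tensor(0.50)
reward = torch.tensor(0.72)

ema_form = bl_dec * baseline + (1 - bl_dec) * reward
update_form = baseline - (1 - bl_dec) * (baseline - reward)
assert torch.allclose(ema_form, update_form)

# REINFORCE with baseline: a positive advantage pushes the sampled
# arc's log-probability up when the loss is minimized
sample_log_prob = torch.tensor(-2.3)
loss = -1 * sample_log_prob * (reward - update_form)
print(loss.item())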
Example #4
def main(args, myargs):
  # create model and ship to GPU
  model = VAE(args).cuda()
  print(model)

  # reproducibility is da best
  set_seed(0)

  opt = torch.optim.Adamax(model.parameters(), lr=args.lr)

  # create datasets / dataloaders
  scale_inv = lambda x: x + 0.5
  ds_transforms = transforms.Compose([transforms.ToTensor(), lambda x: x - 0.5])
  kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': True}

  train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.dataset_dir, train=True,
                                                              download=True, transform=ds_transforms),
                                             batch_size=args.batch_size, shuffle=True, **kwargs)

  test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.dataset_dir, train=False,
                                                             download=True, transform=ds_transforms),
                                            batch_size=args.batch_size, shuffle=False, **{**kwargs, 'drop_last': False})

  # spawn writer
  model_name = 'NB{}_D{}_Z{}_H{}_BS{}_FB{}_LR{}_IAF{}'.format(args.n_blocks, args.depth, args.z_size, args.h_size,
                                                              args.batch_size, args.free_bits, args.lr, args.iaf)

  model_name = 'test' if args.debug else model_name
  # log_dir = join('runs', model_name)
  # sample_dir = join(log_dir, 'samples')
  # writer = SummaryWriter(log_dir=log_dir)
  log_dir = f'{myargs.args.outdir}/iaf'
  os.makedirs(log_dir, exist_ok=True)
  sample_dir = myargs.args.imgdir
  writer = myargs.writer
  maybe_create_dir(sample_dir)

  print_and_save_args(args, log_dir)
  print('logging into %s' % log_dir)
  best_test = float('inf')

  print('starting training')
  kl_meter = AverageMeter()
  bpd_meter = AverageMeter()
  elbo_meter = AverageMeter()
  kl_obj_meter = AverageMeter()
  log_pxz_meter = AverageMeter()

  for epoch in range(args.n_epochs):
    print(f'\nEpoch [{epoch}/{args.n_epochs}]')
    model.train()
    train_log = reset_log()
    kl_meter.reset()
    bpd_meter.reset()
    elbo_meter.reset()
    kl_obj_meter.reset()
    log_pxz_meter.reset()
    summary_d = collections.defaultdict(dict)

    for batch_idx, (input, _) in enumerate(tqdm.tqdm(train_loader, file=myargs.stdout,
                                                     desc=myargs.args.time_str_suffix)):
      if args.train_dummy and batch_idx != 0:
        break
      input = input.cuda()
      x, kl, kl_obj = model(input)

      log_pxz = logistic_ll(x, model.dec_log_stdv, sample=input)
      loss = (kl_obj - log_pxz).sum() / x.size(0)
      elbo = (kl - log_pxz)
      bpd = elbo / (32 * 32 * 3 * np.log(2.))

      opt.zero_grad()
      loss.backward()
      opt.step()

      # train_log['kl'] += [kl.mean()]
      # train_log['bpd'] += [bpd.mean()]
      # train_log['elbo'] += [elbo.mean()]
      # train_log['kl obj'] += [kl_obj.mean()]
      # train_log['log p(x|z)'] += [log_pxz.mean()]
      # for key, value in train_log.items():
      #   print_and_log_scalar(writer, 'train/%s' % key, value, epoch)
      # print()

      kl_meter.update(kl.mean().item())
      bpd_meter.update(bpd.mean().item())
      elbo_meter.update(elbo.mean().item())
      kl_obj_meter.update(kl_obj.mean().item())
      log_pxz_meter.update(log_pxz.mean().item())

    summary_d['kl_meter']['train'] = kl_meter.avg
    summary_d['bpd_meter']['train'] = bpd_meter.avg
    summary_d['elbo_meter']['train'] = elbo_meter.avg
    summary_d['kl_obj_meter']['train'] = kl_obj_meter.avg
    summary_d['log_pxz_meter']['train'] = log_pxz_meter.avg

    model.eval()
    test_log = reset_log()
    kl_meter.reset()
    bpd_meter.reset()
    elbo_meter.reset()
    kl_obj_meter.reset()
    log_pxz_meter.reset()

    with torch.no_grad():
      for batch_idx, (input, _) in enumerate(tqdm.tqdm(test_loader, file=myargs.stdout)):
        input = input.cuda()
        x, kl, kl_obj = model(input)

        log_pxz = logistic_ll(x, model.dec_log_stdv, sample=input)
        loss = (kl_obj - log_pxz).sum() / x.size(0)
        elbo = (kl - log_pxz)
        bpd = elbo / (32 * 32 * 3 * np.log(2.))

        # test_log['kl'] += [kl.mean()]
        # test_log['bpd'] += [bpd.mean()]
        # test_log['elbo'] += [elbo.mean()]
        # test_log['kl obj'] += [kl_obj.mean()]
        # test_log['log p(x|z)'] += [log_pxz.mean()]

        kl_meter.update(kl.mean().item())
        bpd_meter.update(bpd.mean().item())
        elbo_meter.update(elbo.mean().item())
        kl_obj_meter.update(kl_obj.mean().item())
        log_pxz_meter.update(log_pxz.mean().item())

      summary_d['kl_meter']['test'] = kl_meter.avg
      summary_d['bpd_meter']['test'] = bpd_meter.avg
      summary_d['elbo_meter']['test'] = elbo_meter.avg
      summary_d['kl_obj_meter']['test'] = kl_obj_meter.avg
      summary_d['log_pxz_meter']['test'] = log_pxz_meter.avg

      summary_defaultdict2txtfig(summary_d, prefix='', step=epoch,
                                 textlogger=myargs.textlogger, save_fig_sec=60)

      all_samples = model.cond_sample(input)
      # save reconstructions
      out = torch.stack((x, input))  # 2, bs, 3, 32, 32
      out = out.transpose(1, 0).contiguous()  # bs, 2, 3, 32, 32
      out = out.view(-1, x.size(-3), x.size(-2), x.size(-1))

      all_samples += [x]
      all_samples = torch.stack(all_samples)  # L, bs, 3, 32, 32
      all_samples = all_samples.transpose(1, 0)
      all_samples = all_samples.contiguous()  # bs, L, 3, 32, 32
      all_samples = all_samples.view(-1, x.size(-3), x.size(-2), x.size(-1))

      save_image(scale_inv(all_samples), join(sample_dir, 'test_levels_{}.png'.format(epoch)), nrow=12)
      save_image(scale_inv(out), join(sample_dir, 'test_recon_{}.png'.format(epoch)), nrow=12)
      save_image(scale_inv(model.sample(64)), join(sample_dir, 'sample_{}.png'.format(epoch)), nrow=8)

    # for key, value in test_log.items():
    #   print_and_log_scalar(writer, 'test/%s' % key, value, epoch)
    # print()

    # current_test = sum(test_log['bpd']) / batch_idx
    current_test = bpd_meter.avg

    if current_test < best_test:
      best_test = current_test
      print('saving best model')
      torch.save(model.state_dict(), join(log_dir, 'best_model.pth'))
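
The bpd line converts the per-image ELBO from nats to bits per dimension: a CIFAR-10 image has 32 * 32 * 3 = 3072 dimensions, and dividing by log(2) changes the base. A quick standalone check with a made-up ELBO value:

import numpy as np

elbo_nats = 9800.0            # made-up per-image ELBO, in nats
dims = 32 * 32 * 3            # 3072 dimensions in a CIFAR-10 image
bpd = elbo_nats / (dims * np.log(2.))
print(f'{bpd:.3f} bits/dim')  # ~4.602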
Example #5
    def train_controller(self, G, z, y, controller, controller_optim,
                         iteration, pbar):
        """

    :param controller: for ddp training
    :return:
    """
        if comm.is_main_process() and iteration % 1000 == 0:
            pbar.set_postfix_str("ControllerRLAlpha")

        meter_dict = {}

        G.eval()
        controller.train()

        controller.zero_grad()

        sampled_arcs = controller(iteration)

        sample_entropy = get_ddp_attr(controller, 'sample_entropy')
        sample_log_prob = get_ddp_attr(controller, 'sample_log_prob')

        pool_list, logits_list = [], []
        for i in range(self.num_aggregate):
            z_samples = z.sample().to(self.device)
            y_samples = y.sample().to(self.device)
            with torch.set_grad_enabled(False):
                batched_arcs = sampled_arcs[y_samples]
                x = G(z=z_samples, y=y_samples, batched_arcs=batched_arcs)

            pool, logits = self.FID_IS.get_pool_and_logits(x)

            # pool_list.append(pool)
            logits_list.append(logits)

        # pool = np.concatenate(pool_list, 0)
        logits = np.concatenate(logits_list, 0)

        reward_g, _ = self.FID_IS.calculate_IS(logits)
        meter_dict['reward_g'] = reward_g

        # detach to make sure that gradients aren't backpropped through the reward
        reward = torch.tensor(reward_g).cuda()
        sample_entropy_mean = sample_entropy.mean()
        meter_dict['sample_entropy'] = sample_entropy_mean.item()
        reward += self.entropy_weight * sample_entropy_mean

        if self.baseline is None:
            baseline = torch.tensor(reward_g)
        else:
            baseline = self.baseline - (1 - self.bl_dec) * (self.baseline -
                                                            reward)
            # detach to make sure that gradients are not backpropped through the baseline
            baseline = baseline.detach()

        sample_log_prob_mean = sample_log_prob.mean()
        meter_dict['sample_log_prob'] = sample_log_prob_mean.item()
        loss = -1 * sample_log_prob_mean * (reward - baseline)

        meter_dict['reward'] = reward.item()
        meter_dict['baseline'] = baseline.item()
        meter_dict['loss'] = loss.item()

        loss.backward(retain_graph=False)

        grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                                   self.child_grad_bound)
        # clip_grad_norm_ may return a 0-dim tensor; log a plain float
        meter_dict['grad_norm'] = float(grad_norm)

        controller_optim.step()

        baseline_list = comm.all_gather(baseline)
        baseline_mean = sum(map(lambda v: v.item(),
                                baseline_list)) / len(baseline_list)
        baseline.fill_(baseline_mean)
        self.baseline = baseline

        if iteration % self.log_every_iter == 0:
            self.print_distribution(iteration=iteration, print_interval=10)
            if len(sampled_arcs) <= 10:
                self.logger.info('\nsampled arcs: \n%s' %
                                 sampled_arcs.cpu().numpy())
            self.myargs.textlogger.logstr(
                iteration,
                sampled_arcs='\n' +
                np.array2string(sampled_arcs.cpu().numpy(), threshold=np.inf))
            default_dicts = collections.defaultdict(dict)
            for meter_k, meter in meter_dict.items():
                if meter_k in ['reward', 'baseline']:
                    default_dicts['reward_baseline'][meter_k] = meter
                else:
                    default_dicts[meter_k][meter_k] = meter
            summary_defaultdict2txtfig(default_dict=default_dicts,
                                       prefix='train_controller',
                                       step=iteration,
                                       textlogger=self.myargs.textlogger)
        comm.synchronize()
        return
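
The line batched_arcs = sampled_arcs[y_samples] turns one architecture per class into a per-sample batch by row selection: row i of batched_arcs is the arc for label y_samples[i]. A minimal shape check with illustrative sizes:

import torch

n_classes, n_layers, batch = 10, 6, 4
sampled_arcs = torch.randint(0, 5, (n_classes, n_layers))  # one arc per class
y_samples = torch.tensor([3, 3, 7, 0])                     # sampled labels

batched_arcs = sampled_arcs[y_samples]  # shape: (batch, n_layers)
assert batched_arcs.shape == (batch, n_layers)
assert torch.equal(batched_arcs[0], sampled_arcs[3])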