Example #1
class Trainer(object):
    """Training network parameters and theta separately.
  """
    def __init__(self,
                 network,
                 w_lr=0.01,
                 w_mom=0.9,
                 w_wd=1e-4,
                 t_lr=0.001,
                 t_wd=3e-3,
                 t_beta=(0.5, 0.999),
                 init_temperature=5.0,
                 temperature_decay=0.965,
                 logger=logging,
                 lr_scheduler={'T_max': 200},
                 gpus=[0],
                 save_theta_prefix=''):
        assert isinstance(network, FBNet)
        network.apply(weights_init)
        network = network.train().cuda()
        if isinstance(gpus, str):
            gpus = [int(i) for i in gpus.strip().split(',')]
        # Grab theta and the weight parameters before wrapping: DataParallel
        # does not forward custom attributes such as `theta`, and parameters()
        # returns a one-shot generator, so materialize it for reuse.
        theta_params = network.theta
        mod_params = list(network.parameters())
        network = DataParallel(network, gpus)
        self.gpus = gpus
        self._mod = network
        self.theta = theta_params
        self.w = mod_params
        self._tem_decay = temperature_decay
        self.temp = init_temperature
        self.logger = logger
        self.save_theta_prefix = save_theta_prefix

        self._acc_avg = AvgrageMeter('acc')
        self._ce_avg = AvgrageMeter('ce')
        self._lat_avg = AvgrageMeter('lat')
        self._loss_avg = AvgrageMeter('loss')

        self.w_opt = torch.optim.SGD(mod_params,
                                     w_lr,
                                     momentum=w_mom,
                                     weight_decay=w_wd)

        self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)

        self.t_opt = torch.optim.Adam(theta_params,
                                      lr=t_lr,
                                      betas=t_beta,
                                      weight_decay=t_wd)

    def train_w(self, input, target, decay_temperature=False):
        """Update model parameters.
    """
        self.w_opt.zero_grad()
        loss, ce, lat, acc, energy = self._mod(input, target, self.temp)
        loss.backward()
        self.w_opt.step()
        if decay_temperature:
            tmp = self.temp
            self.temp *= self._tem_decay
            self.logger.info("Change temperature from %.5f to %.5f" %
                             (tmp, self.temp))
        return loss.item(), ce.item(), lat.item(), acc.item(), energy.item()

    def train_t(self, input, target, decay_temperature=False):
        """Update theta.
    """
        self.t_opt.zero_grad()
        loss, ce, lat, acc, energy = self._mod(input, target, self.temp)
        loss.backward()
        self.t_opt.step()
        if decay_temperature:
            tmp = self.temp
            self.temp *= self._tem_decay
            self.logger.info("Change temperature from %.5f to %.5f" %
                             (tmp, self.temp))
        return loss.item(), ce.item(), lat.item(), acc.item(), energy.item()

    def decay_temperature(self, decay_ratio=None):
        tmp = self.temp
        if decay_ratio is None:
            self.temp *= self._tem_decay
        else:
            self.temp *= decay_ratio
        self.logger.info("Change temperature from %.5f to %.5f" %
                         (tmp, self.temp))

    def _step(self, input, target, epoch, step, log_frequence, func):
        """Perform one step of training.
    """
        input = input.cuda()
        target = target.cuda()
        loss, ce, lat, acc, energy = func(input, target)

        # Batch size for the throughput computation below; DataParallel does not
        # expose the wrapped module's attributes, so read it from the input.
        batch_size = input.size(0)

        self._acc_avg.update(acc)
        self._ce_avg.update(ce)
        self._lat_avg.update(lat)
        self._loss_avg.update(loss)

        if step > 1 and (step % log_frequence == 0):
            self.toc = time.time()
            speed = 1.0 * (batch_size * log_frequence) / (self.toc - self.tic)

            self.logger.info(
                "Epoch[%d] Batch[%d] Speed: %.6f samples/sec %s %s %s %s" %
                (epoch, step, speed, self._loss_avg, self._acc_avg,
                 self._ce_avg, self._lat_avg))
            # map() is lazy in Python 3, so reset the meters explicitly.
            for avg in (self._loss_avg, self._acc_avg, self._ce_avg,
                        self._lat_avg):
                avg.reset()
            self.tic = time.time()

    def search(self,
               train_w_ds,
               train_t_ds,
               total_epoch=90,
               start_w_epoch=10,
               log_frequence=100):
        """Search model.
    """
        assert start_w_epoch >= 1, "start_w_epoch must be at least 1"
        self.tic = time.time()
        for epoch in range(start_w_epoch):
            self.logger.info("Start to train w for epoch %d" % epoch)
            for step, (input, target) in enumerate(train_w_ds):
                self._step(input, target, epoch, step, log_frequence,
                           lambda x, y: self.train_w(x, y, False))
                self.w_sche.step()
                # print(self.w_sche.last_epoch, self.w_opt.param_groups[0]['lr'])

        self.tic = time.time()
        for epoch in range(total_epoch):
            self.logger.info("Start to train theta for epoch %d" %
                             (epoch + start_w_epoch))
            for step, (input, target) in enumerate(train_t_ds):
                self._step(input, target, epoch + start_w_epoch, step,
                           log_frequence,
                           lambda x, y: self.train_t(x, y, False))
                self.save_theta(
                    './theta-result/%s_theta_epoch_%d.txt' %
                    (self.save_theta_prefix, epoch + start_w_epoch))
            self.decay_temperature()
            self.logger.info("Start to train w for epoch %d" %
                             (epoch + start_w_epoch))
            for step, (input, target) in enumerate(train_w_ds):
                self._step(input, target, epoch + start_w_epoch, step,
                           log_frequence,
                           lambda x, y: self.train_w(x, y, False))
                self.w_sche.step()

    def save_theta(self, save_path='theta.txt'):
        """Save theta.
    """
        res = []
        with open(save_path, 'w') as f:
            for t in self.theta:
                t_list = list(t.detach().cpu().numpy())
                res.append(t_list)
                s = ' '.join([str(tmp) for tmp in t_list])
                f.write(s + '\n')
        return res
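
A minimal usage sketch for the FBNet trainer above. The FBNet constructor arguments and the two dataloaders are placeholders, not part of the original listing; the real ones come from the surrounding repository.

# Hypothetical wiring (names marked as placeholders are assumptions).
model = FBNet(num_classes=10)  # placeholder constructor arguments
trainer = Trainer(model,
                  w_lr=0.01,
                  t_lr=0.001,
                  gpus='0',
                  save_theta_prefix='cifar10')
# train_w_ds / train_t_ds: two disjoint DataLoader splits, one used to update
# the weights w and one to update the architecture parameters theta (placeholders).
trainer.search(train_w_ds, train_t_ds,
               total_epoch=90,
               start_w_epoch=10,
               log_frequence=100)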
Example #2
class Trainer(object):
  """Training network parameters and alpha.
  """
  def __init__(self, network,
               w_lr=0.01,
               w_mom=0.9,
               w_wd=1e-4,
               t_lr=0.001,
               t_wd=3e-3,
               t_beta=(0.5, 0.999),
               init_temperature=5.0,
               temperature_decay=0.965,
               logger=logging,
               lr_scheduler={'T_max' : 200},
               gpus=[0],
               save_theta_prefix='',
               resource_weight=0.001):
    assert isinstance(network, SNAS)
    network.apply(weights_init)
    network = network.train().cuda()
    self._criterion = nn.CrossEntropyLoss().cuda()

    alpha_params = network.arch_parameters()
    mod_params = network.model_parameters()
    self.alpha = alpha_params
    if isinstance(gpus, str):
      gpus = [int(i) for i in gpus.strip().split(',')]
    network = DataParallel(network, gpus)
    self._mod = network
    self.gpus = gpus

    self.w = mod_params
    self._tem_decay = temperature_decay
    self.temp = init_temperature
    self.logger = logger
    self.save_theta_prefix = save_theta_prefix
    self._resource_weight = resource_weight

    self._loss_avg = AvgrageMeter('loss')
    self._acc_avg = AvgrageMeter('acc')
    self._res_cons_avg = AvgrageMeter('resource-constraint')

    self.w_opt = torch.optim.SGD(
                    mod_params,
                    w_lr,
                    momentum=w_mom,
                    weight_decay=w_wd)
    self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)
    self.t_opt = torch.optim.Adam(
                    alpha_params,
                    lr=t_lr, betas=t_beta,
                    weight_decay=t_wd)

  def _acc(self, logits, target):
    batch_size = target.size()[0]
    pred = torch.argmax(logits, dim=1)
    acc = torch.sum(pred == target).float() / batch_size
    return acc

  def train(self, input, target):
    """Update parameters.
    """
    self.w_opt.zero_grad()
    self.t_opt.zero_grad()  # clear alpha gradients as well before the joint update
    logits, costs = self._mod(input, self.temp)
    acc = self._acc(logits, target)
    costs = costs.mean() * self._resource_weight
    loss = self._criterion(logits, target) + costs
    loss.backward()
    self.w_opt.step()
    self.t_opt.step()
    # Return plain floats so the meters do not keep the autograd graph alive.
    return loss.item(), costs.item(), acc.item()
  
  def decay_temperature(self, decay_ratio=None):
    tmp = self.temp
    if decay_ratio is None:
      self.temp *= self._tem_decay
    else:
      self.temp *= decay_ratio
    self.logger.info("Change temperature from %.5f to %.5f" % (tmp, self.temp))
  
  def _step(self, input, target, 
            epoch, step,
            log_frequence,
            func):
    """Perform one step of training.
    """
    input = input.cuda()
    target = target.cuda()
    loss, res_cost, acc = func(input, target)

    # Get status
    batch_size = input.size()[0]

    self._loss_avg.update(loss)
    self._res_cons_avg.update(res_cost)
    self._acc_avg.update(acc)

    if step > 1 and (step % log_frequence == 0):
      self.toc = time.time()
      speed = 1.0 * (batch_size * log_frequence) / (self.toc - self.tic)

      self.logger.info("Epoch[%d] Batch[%d] Speed: %.6f samples/sec %s %s %s" 
              % (epoch, step, speed, self._loss_avg, self._acc_avg,
                 self._res_cons_avg))
      # map() is lazy in Python 3, so reset the meters explicitly.
      for avg in (self._loss_avg, self._res_cons_avg, self._acc_avg):
        avg.reset()
      self.tic = time.time()
  
  def search(self, train_ds,
            epochs=90,
            log_frequence=100):
    """Search model.
    """

    self.tic = time.time()
    for epoch in range(epochs):
      self.logger.info("Start to train for epoch %d" % (epoch))
      for step, (input, target) in enumerate(train_ds):
        self._step(input, target, epoch, 
                   step, log_frequence,
                   lambda x, y: self.train(x, y))
      self.save_alpha('./alpha-result/%s_theta_epoch_%d.txt' % 
                  (self.save_theta_prefix, epoch))
      self.decay_temperature()
      self.w_sche.step()

  def save_alpha(self, save_path='alpha.txt'):
    """Save alpha.
    """
    res = []
    with open(save_path, 'w') as f:
      for i, t in enumerate(self.alpha):
        n = 'normal' if i == 0 else 'reduce'
        assert i <= 1
        num_rows = t.size(0)
        f.write(n + ':' + '\n')
        for j in range(num_rows):
          t_list = list(t[j].detach().cpu().numpy())
          res.append(t_list)
          s = ' '.join([str(v) for v in t_list])
          f.write(s + '\n')
    return res
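
The trainers in these listings all track metrics with an AvgrageMeter helper that is not shown. A minimal sketch consistent with how it is used above (a name in the constructor, update()/reset(), getValue() in Example #4, and a readable str() for logging) might look like the following; the repository's own implementation may differ.

class AvgrageMeter(object):
    """Running average of a scalar metric (sketch, not the original helper)."""
    def __init__(self, name=''):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n

    def getValue(self):
        return self.sum / max(self.cnt, 1)

    def __str__(self):
        return '%s: %.5f' % (self.name, self.getValue())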
Example #3
def train(cfg):
    # Initialize
    init_seeds()
    image_size_min = 6.6   # 320 / 32 / 1.5
    image_size_max = 28.5  # 608 / 32 * 1.5
    if cfg.TRAIN.MULTI_SCALE:
        image_size_min = round(cfg.TRAIN.IMAGE_SIZE / 32 / 1.5)
        image_size_max = round(cfg.TRAIN.IMAGE_SIZE / 32 * 1.5)
        image_size = image_size_max * 32  # start with the maximum multi-scale size
        print(f"Using multi-scale {image_size_min * 32} - {image_size}")

    # Remove previous results
    for files in glob.glob("results.txt"):
        os.remove(files)

    # Initialize model
    model = YOLOv3(cfg).to(device)

    # Optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=cfg.TRAIN.LR,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.DECAY,
                          nesterov=True)

    # Define the loss function calculation formula of the model
    compute_loss = YoloV3Loss(cfg)

    epoch = 0
    start_epoch = 0
    best_maps = 0.0
    context = None

    # Dataset
    # Apply augmentation hyperparameters
    train_dataset = VocDataset(anno_file_type=cfg.TRAIN.DATASET,
                               image_size=cfg.TRAIN.IMAGE_SIZE,
                               cfg=cfg)
    # Dataloader
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.TRAIN.MINI_BATCH_SIZE,
                                  num_workers=cfg.TRAIN.WORKERS,
                                  shuffle=cfg.TRAIN.SHUFFLE,
                                  pin_memory=cfg.TRAIN.PIN_MENORY)

    if cfg.TRAIN.WEIGHTS.endswith(".pth"):
        state = torch.load(cfg.TRAIN.WEIGHTS, map_location=device)
        # load model
        try:
            state["state_dict"] = {
                k: v
                for k, v in state["state_dict"].items()
                if model.state_dict()[k].numel() == v.numel()
            }
            model.load_state_dict(state["state_dict"], strict=False)
        except KeyError as e:
            error_msg = f"{cfg.TRAIN.WEIGHTS} is not compatible with {cfg.CONFIG_FILE}. "
            error_msg += f"Specify --weights `` or specify a --config-file "
            error_msg += f"compatible with {cfg.TRAIN.WEIGHTS}. "
            raise KeyError(error_msg) from e

        # load optimizer
        if state["optimizer"] is not None:
            optimizer.load_state_dict(state["optimizer"])
            best_maps = state["best_maps"]

        # load results
        if state.get("training_results") is not None:
            with open("results.txt", "w") as file:
                file.write(state["training_results"])  # write results.txt

        start_epoch = (state["batches"] + 1) // len(train_dataloader)
        del state

    elif len(cfg.TRAIN.WEIGHTS) > 0:
        # possible weights are "*.weights", "yolov3-tiny.conv.15",  "darknet53.conv.74" etc.
        load_darknet_weights(model, cfg.TRAIN.WEIGHTS)
    else:
        print("Pre training model weight not loaded.")

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        # skip print amp info
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    # source https://arxiv.org/pdf/1812.01187.pdf
    scheduler = CosineDecayLR(optimizer,
                              max_batches=cfg.TRAIN.MAX_BATCHES,
                              lr=cfg.TRAIN.LR,
                              warmup=cfg.TRAIN.WARMUP_BATCHES)

    # Initialize distributed training
    if device.type != "cpu" and torch.cuda.device_count(
    ) > 1 and torch.distributed.is_available():
        dist.init_process_group(
            backend="nccl",  # "distributed backend"
            # distributed training init method
            init_method="tcp://127.0.0.1:9999",
            # number of nodes for distributed training
            world_size=1,
            # distributed training node rank
            rank=0)
        model = torch.nn.parallel.DistributedDataParallel(model)
        model.backbone = model.module.backbone

    # Model EMA
    # TODO: ema = ModelEMA(model, decay=0.9998)

    # Start training
    batches_num = len(train_dataloader)  # number of batches
    # 'loss_GIOU', 'loss_Confidence', 'loss_Classification' 'loss'
    results = (0, 0, 0, 0)
    epochs = cfg.TRAIN.MAX_BATCHES // len(train_dataloader)
    print(f"Using {cfg.TRAIN.WORKERS} dataloader workers.")
    print(
        f"Starting training {cfg.TRAIN.MAX_BATCHES} batches for {epochs} epochs..."
    )

    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        model.train()

        # init batches
        batches = 0
        mean_losses = torch.zeros(4)
        print("\n")
        print(
            ("%10s" * 7) %
            ("Batch", "memory", "GIoU", "conf", "cls", "total", " image_size"))
        progress_bar = tqdm(enumerate(train_dataloader), total=batches_num)
        for index, (images, small_label_bbox, medium_label_bbox,
                    large_label_bbox, small_bbox, medium_bbox,
                    large_bbox) in progress_bar:

            # number integrated batches (since train start)
            batches = index + len(train_dataloader) * epoch

            scheduler.step(batches)

            images = images.to(device)

            small_label_bbox = small_label_bbox.to(device)
            medium_label_bbox = medium_label_bbox.to(device)
            large_label_bbox = large_label_bbox.to(device)

            small_bbox = small_bbox.to(device)
            medium_bbox = medium_bbox.to(device)
            large_bbox = large_bbox.to(device)

            # Hyper parameter Burn-in
            if batches <= cfg.TRAIN.WARMUP_BATCHES:
                for m in model.named_modules():
                    if m[0].endswith('BatchNorm2d'):
                        m[1].track_running_stats = batches == cfg.TRAIN.WARMUP_BATCHES

            # Run model
            pred, raw = model(images)

            # Compute loss
            loss, loss_giou, loss_conf, loss_cls = compute_loss(
                pred, raw, small_label_bbox, medium_label_bbox,
                large_label_bbox, small_bbox, medium_bbox, large_bbox)

            # Compute gradient
            if mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Optimize accumulated gradient
            if batches % (cfg.TRAIN.BATCH_SIZE // cfg.TRAIN.MINI_BATCH_SIZE) == 0:
                optimizer.step()
                optimizer.zero_grad()
                # TODO: ema.update(model)

            # Print batch results
            # update mean losses
            loss_items = torch.tensor([loss_giou, loss_conf, loss_cls, loss])
            mean_losses = (mean_losses * index + loss_items) / (index + 1)
            memory = f"{torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0:.2f}G"
            context = ("%10s" * 2 + "%10.3g" * 5) % (
                "%g/%g" % (batches + 1, cfg.TRAIN.MAX_BATCHES), memory,
                *mean_losses, train_dataset.image_size)
            progress_bar.set_description(context)

            # Multi-Scale training
            if cfg.TRAIN.MULTI_SCALE:
                #  adjust img_size (67% - 150%) every 10 batch size
                if batches % cfg.TRAIN.RESIZE_INTERVAL == 0:
                    train_dataset.image_size = random.randrange(
                        image_size_min, image_size_max + 1) * 32

            # Write Tensorboard results
            if tb_writer:
                # 'loss_GIOU', 'loss_Confidence', 'loss_Classification' 'loss'
                titles = ["GIoU", "Confidence", "Classification", "Train loss"]
                for xi, title in zip(
                        list(mean_losses) + list(results), titles):
                    tb_writer.add_scalar(title, xi, index)

        # Process epoch results
        # TODO: ema.update_attr(model)
        final_epoch = epoch + 1 == epochs

        # Calculate mAP
        # skip first epoch
        maps = 0.
        if epoch > 0:
            maps = evaluate(cfg, args)

        # Write epoch results
        with open("results.txt", "a") as f:
            # 'loss_GIOU', 'loss_Confidence', 'loss_Classification' 'loss', 'maps'
            f.write(context + "%10.3g" % maps)
            f.write("\n")

        # Update best mAP
        if maps > best_maps:
            best_maps = maps

        # Save training results
        with open("results.txt", 'r') as f:
            # Create checkpoint
            state = {
                'batches': batches,
                'best_maps': maps,
                'training_results': f.read(),
                'state_dict': model.state_dict(),
                'optimizer': None if final_epoch else optimizer.state_dict()
            }

        # Save last checkpoint
        torch.save(state, "weights/checkpoint.pth")

        # Save best checkpoint
        if best_maps == maps:
            state = {
                'batches': -1,
                'best_maps': None,
                'training_results': None,
                'state_dict': model.state_dict(),
                'optimizer': None
            }
            torch.save(state, "weights/model_best.pth")

        # Delete checkpoint
        del state

    print(f"{epoch - start_epoch} epochs completed "
          f"in {(time.time() - start_time) / 3600:.3f} hours.\n")
    if torch.cuda.device_count() > 1:
        dist.destroy_process_group()
    torch.cuda.empty_cache()
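
The CosineDecayLR used in this example is a project-local scheduler, not torch.optim.lr_scheduler.CosineAnnealingLR; it is constructed with max_batches, lr, and warmup and stepped with the global batch index. A rough sketch of a linear-warmup-plus-cosine schedule that matches that calling convention is shown below; it is an assumption, not the repository's implementation.

import math

class CosineDecayLR(object):
    """Sketch: linear warmup, then cosine decay, stepped by global batch index."""
    def __init__(self, optimizer, max_batches, lr, warmup=0):
        self.optimizer = optimizer
        self.max_batches = max_batches
        self.base_lr = lr
        self.warmup = warmup

    def step(self, batches):
        if self.warmup and batches < self.warmup:
            # Linear ramp from 0 to base_lr during the burn-in phase.
            lr = self.base_lr * batches / self.warmup
        else:
            progress = (batches - self.warmup) / max(self.max_batches - self.warmup, 1)
            lr = 0.5 * self.base_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))
        for group in self.optimizer.param_groups:
            group['lr'] = lr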
Example #4
class Trainer(object):
    """Training network parameters and theta separately.
  """
    def __init__(self,
                 network,
                 w_lr=0.01,
                 w_mom=0.9,
                 w_wd=1e-4,
                 t_lr=0.001,
                 t_wd=3e-3,
                 t_beta=(0.5, 0.999),
                 init_temperature=5.0,
                 temperature_decay=0.965,
                 logger=logging,
                 lr_scheduler={'T_max': 200},
                 gpus=[0],
                 save_theta_prefix='',
                 save_tb_log=''):
        assert isinstance(network, FBNet)
        network.apply(weights_init)
        network = network.train().cuda()
        if isinstance(gpus, str):
            gpus = [int(i) for i in gpus.strip().split(',')]
        # network = DataParallel(network, gpus)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        network.to(device)
        self.gpus = gpus
        self._mod = network
        theta_params = network.theta
        mod_params = network.parameters()
        self.theta = theta_params
        self.w = mod_params
        self._tem_decay = temperature_decay
        self.temp = init_temperature
        self.logger = logger
        self.tensorboard = Tensorboard('logs/' + save_tb_log)
        self.save_theta_prefix = save_theta_prefix

        self._acc_avg = AvgrageMeter('acc')
        self._ce_avg = AvgrageMeter('ce')
        self._lat_avg = AvgrageMeter('lat')
        self._loss_avg = AvgrageMeter('loss')
        self._ener_avg = AvgrageMeter('ener')

        self.w_opt = torch.optim.SGD(mod_params,
                                     w_lr,
                                     momentum=w_mom,
                                     weight_decay=w_wd)

        self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)

        self.t_opt = torch.optim.Adam(theta_params,
                                      lr=t_lr,
                                      betas=t_beta,
                                      weight_decay=t_wd)

    def train_w(self, input, target, decay_temperature=False):
        """Update model parameters.
    """
        self.w_opt.zero_grad()
        loss, ce, lat, acc, ener = self._mod(input, target, self.temp)
        loss.backward()
        self.w_opt.step()
        if decay_temperature:
            tmp = self.temp
            self.temp *= self._tem_decay
            self.logger.info("Change temperature from %.5f to %.5f" %
                             (tmp, self.temp))
        return loss.item(), ce.item(), lat.item(), acc.item(), ener.item()

    def train_t(self, input, target, decay_temperature=False):
        """Update theta.
    """
        self.t_opt.zero_grad()
        loss, ce, lat, acc, ener = self._mod(input, target, self.temp)
        loss.backward()
        self.t_opt.step()
        if decay_temperature:
            tmp = self.temp
            self.temp *= self._tem_decay
            self.logger.info("Change temperature from %.5f to %.5f" %
                             (tmp, self.temp))
        return loss.item(), ce.item(), lat.item(), acc.item(), ener.item()

    def decay_temperature(self, decay_ratio=None):
        tmp = self.temp
        if decay_ratio is None:
            self.temp *= self._tem_decay
        else:
            self.temp *= decay_ratio
        self.logger.info("Change temperature from %.5f to %.5f" %
                         (tmp, self.temp))

    def _step(self, input, target, epoch, step, log_frequence, func):
        """Perform one step of training.
    """
        input = input.cuda()
        target = target.cuda()
        loss, ce, lat, acc, ener = func(input, target)

        # Get status
        batch_size = self._mod.batch_size

        self._acc_avg.update(acc)
        self._ce_avg.update(ce)
        self._lat_avg.update(lat)
        self._loss_avg.update(loss)
        self._ener_avg.update(ener)

        if step > 1 and (step % log_frequence == 0):
            self.toc = time.time()
            speed = 1.0 * (batch_size * log_frequence) / (self.toc - self.tic)
            self.tensorboard.log_scalar('Total Loss',
                                        self._loss_avg.getValue(), step)
            self.tensorboard.log_scalar('Accuracy', self._acc_avg.getValue(),
                                        step)
            self.tensorboard.log_scalar('Latency', self._lat_avg.getValue(),
                                        step)
            self.tensorboard.log_scalar('Energy', self._ener_avg.getValue(),
                                        step)
            self.logger.info(
                "Epoch[%d] Batch[%d] Speed: %.6f samples/sec %s %s %s %s %s" %
                (epoch, step, speed, self._loss_avg, self._acc_avg,
                 self._ce_avg, self._lat_avg, self._ener_avg))
            # map() is lazy in Python 3, so reset the meters explicitly.
            for avg in (self._loss_avg, self._acc_avg, self._ce_avg,
                        self._lat_avg, self._ener_avg):
                avg.reset()
            self.tic = time.time()

    def search(self,
               train_w_ds,
               train_t_ds,
               total_epoch=90,
               start_w_epoch=10,
               log_frequence=100):
        """Search model.
    """
        assert start_w_epoch >= 1, "Start to train w"
        self.tic = time.time()
        for epoch in range(start_w_epoch):
            self.logger.info("Start to train w for epoch %d" % epoch)
            for step, (input, target) in enumerate(train_w_ds):
                self._step(input, target, epoch, step, log_frequence,
                           lambda x, y: self.train_w(x, y, False))
                self.w_sche.step()
                # log_scalar(tag, value, step): plot the current learning rate
                # against the scheduler step.
                self.tensorboard.log_scalar('Learning rate curve',
                                            self.w_opt.param_groups[0]['lr'],
                                            self.w_sche.last_epoch)
                # print(self.w_sche.last_epoch, self.w_opt.param_groups[0]['lr'])

        self.tic = time.time()
        for epoch in range(total_epoch):
            self.logger.info("Start to train theta for epoch %d" %
                             (epoch + start_w_epoch))
            for step, (input, target) in enumerate(train_t_ds):
                self._step(input, target, epoch + start_w_epoch, step,
                           log_frequence,
                           lambda x, y: self.train_t(x, y, False))
                self.save_theta(
                    './theta-result/%s_theta_epoch_%d.txt' %
                    (self.save_theta_prefix, epoch + start_w_epoch), epoch)
            self.decay_temperature()
            self.logger.info("Start to train w for epoch %d" %
                             (epoch + start_w_epoch))
            for step, (input, target) in enumerate(train_w_ds):
                self._step(input, target, epoch + start_w_epoch, step,
                           log_frequence,
                           lambda x, y: self.train_w(x, y, False))
                self.w_sche.step()
        # Close the writer only after the whole search has finished.
        self.tensorboard.close()

    def save_theta(self, save_path='theta.txt', epoch=0):
        """Save theta.
    """
        res = []
        directory = os.path.dirname(save_path)
        if not os.path.exists(directory):
            os.makedirs(directory)
        with open(save_path, 'w') as f:
            for i, t in enumerate(self.theta):
                t_list = list(t.detach().cpu().numpy())
                # Pad shorter rows so every layer has the same width in the heatmap.
                if len(t_list) < 9:
                    t_list.append(0.00)
                max_index = t_list.index(max(t_list))
                self.tensorboard.log_scalar('Layer %s' % str(i), max_index + 1,
                                            epoch)
                res.append(t_list)
                s = ' '.join([str(tmp) for tmp in t_list])
                f.write(s + '\n')

            val = np.array(res)
            ax = sns.heatmap(val, cbar=True, annot=True)
            ax.figure.savefig(save_path[:-3] + 'png')
            #self.tensorboard.log_image('Theta Values',val,epoch)
            plt.close()
        return res
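
Example #4 logs through a small Tensorboard wrapper that is not part of the listing. A minimal sketch matching the calls used above (a log directory in the constructor, log_scalar(tag, value, step), log_image, and close()) could be built on torch.utils.tensorboard; this is an assumption, not the repository's own class.

from torch.utils.tensorboard import SummaryWriter

class Tensorboard(object):
    """Thin wrapper sketch around SummaryWriter, matching the calls above."""
    def __init__(self, logdir):
        self.writer = SummaryWriter(log_dir=logdir)

    def log_scalar(self, tag, value, step):
        self.writer.add_scalar(tag, value, global_step=step)

    def log_image(self, tag, image, step):
        # `image` is expected as an HxW array here (e.g. the theta heatmap).
        self.writer.add_image(tag, image, global_step=step, dataformats='HW')

    def close(self):
        self.writer.close()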