示例#1
0
def binary_search(model, target):
    """
    Search for the pruning threshold ``model.pt`` that brings the FLOPs ratio
    of the pruned network to ``target``.

    Each iteration applies the current threshold with ``model.set_parameters()``
    and measures the FLOPs (and parameter) ratio relative to the network
    measured on entry. It then moves the threshold by ``step``: up while the
    ratio is above the target, and down (clamped at 0) while it is below.
    Whenever the ratio crosses the target between two consecutive iterations,
    the step is halved. Once the target is bracketed, the search therefore
    behaves like a bisection.

    :param model: network exposing ``set_parameters()`` and a mutable threshold
        attribute ``pt``; it is also passed to ``get_flops``/``get_parameters``.
    :param target: desired ratio of pruned FLOPs to the FLOPs measured on entry.
    :return: the final threshold ``model.pt``.
    """
    step = 0.01
    status = 1.0  # FLOPs ratio before any threshold has been applied
    stop = 0.001  # tolerance on |status - target|
    counter = 1
    max_iter = 100
    flops = get_flops(model)
    params = get_parameters(model)

    while abs(status - target) > stop and counter <= max_iter:
        status_old = status
        # Apply the current threshold and measure the resulting network.
        model.set_parameters()
        flops_prune = get_flops(model)
        status = flops_prune / flops
        params_prune = get_parameters(model)
        params_compression_ratio = params_prune / params

        string = ('Iter {:<3}: current step={:1.8f}, current threshold={:2.8f}, status (FLOPs ratio) = {:2.4f}, '
                  'params ratio = {:2.4f}.\n').format(counter, step, model.pt, status, params_compression_ratio)
        print(string)

        if abs(status - target) <= stop:
            print(
                'The target compression ratio is achieved. The loop is stopped'
            )
            break

        # Halve the step once the ratio has crossed the target since the
        # previous iteration (there is no previous measurement on iteration 1).
        if counter > 1 and (status_old >= target) == (status < target):
            step /= 2

        # Move the threshold towards the target.
        if status > target:
            model.pt += step
        elif status < target:
            if model.pt <= 0:
                # The threshold is already at its lower bound, yet the network
                # is still pruned more than requested. Lowering it further is
                # impossible, and re-applying the same threshold presumably
                # yields the same ratio, so stop instead of spinning until
                # max_iter.
                print('Status {} or threshold {} is out of range'.format(
                    status, model.pt))
                return model.pt
            model.pt = max(model.pt - step, 0)

        counter += 1
        # deal with the unexpected status
        if model.pt < 0 or status <= 0:
            print('Status {} or threshold {} is out of range'.format(
                status, model.pt))
            return model.pt

    if abs(status - target) > stop:
        # Iteration budget exhausted without reaching the tolerance.
        print('The target compression ratio {} was not reached within {} iterations '
              '(final status {:2.4f}).'.format(target, max_iter, status))
    return model.pt
示例#2
0
    def reset_after_optimization(self):
        """
        Move the trainer from the searching stage to the converging stage.

        If the trainer is still searching and not in test-only mode, the sparse
        model is pruned with ``reset_after_searching()``, and a new optimizer and
        scheduler are created for the converging stage.

        The remaining steps run in every mode:

        * the model is written to the log file;
        * the FLOPs/parameter compression ratios are recomputed;
        * a new TensorBoard writer is opened when summaries are enabled;
        * ``epochs_searching`` is reloaded from ``epochs.pt`` when that file
          exists in the checkpoint directory.
        """
        # During the reloading and testing phase, the searched sparse model is already loaded at initialization.
        # During the training phase, the searched sparse model is just there.
        if not self.converging and not self.args.test_only:
            self.model.get_model().reset_after_searching()
            self.converging = True
            self.optimizer = utility.make_optimizer_dhp(
                self.args, self.model, converging=self.converging)
            # lr_decay_step is "<searching>+<converging>". A single value (no '+')
            # is also used for the converging stage instead of raising IndexError.
            decay_steps = self.args.lr_decay_step.split('+')
            decay_step = decay_steps[1] if len(decay_steps) > 1 else decay_steps[0]
            self.scheduler = utility.make_scheduler_dhp(
                self.args,
                self.optimizer,
                int(decay_step),
                converging=self.converging)

        print(self.model.get_model(), file=self.ckp.log_file)

        # calculate flops and number of parameters relative to the original network
        self.flops_prune = get_flops(self.model.get_model())
        self.flops_compression_ratio = self.flops_prune / self.flops
        self.params_prune = get_parameters(self.model.get_model())
        self.params_compression_ratio = self.params_prune / self.params

        # reset tensorboardX summary
        if not self.args.test_only and self.args.summary:
            self.writer = SummaryWriter(os.path.join(self.args.dir_save,
                                                     self.args.save),
                                        comment='converging')

        # get the searching epochs
        epochs_file = os.path.join(self.ckp.dir, 'epochs.pt')
        if os.path.exists(epochs_file):
            self.epochs_searching = torch.load(epochs_file)
示例#3
0
    def reset_after_searching(self):
        """
        Prune the searched sparse model and prepare the trainer for converging.

        Phase handling:

        * Phases 1 and 3: the model is reset here, and the optimizer and
          scheduler are rebuilt for the converging stage.
        * Phase 2: the model is reset at initialization, and the optimizer and
          scheduler are not used.
        * Phase 4: the model is reset at initialization, and the optimizer and
          scheduler were already created when the trainer was initialized.

        ``self.converging`` is True once the rebuild happens, so the optimizer
        and scheduler factories need no separate learning-rate-adjust flag.

        In every phase, the FLOPs/parameter ratios are refreshed, a new
        TensorBoard writer is opened when summaries are enabled, and the number
        of searching epochs is reloaded from ``epochs.pt`` if it exists.
        """
        needs_rebuild = not self.converging and not self.args.test_only
        if needs_rebuild:
            self.model.get_model().reset_after_searching()
            self.converging = True
            # Release the searching-stage optimizer/scheduler before building new ones.
            del self.optimizer, self.scheduler
            torch.cuda.empty_cache()
            # "a+b" selects the second field; a single value is used as-is.
            decay_fields = self.args.decay.split('+')
            decay = decay_fields[1] if len(decay_fields) != 1 else self.args.decay
            self.optimizer = utility.make_optimizer_dhp(
                self.args, self.model, converging=self.converging)
            self.scheduler = utility.make_scheduler_dhp(
                self.args, self.optimizer, decay, converging=self.converging)

        # Complexity of the current network relative to the unpruned one.
        self.flops_prune = get_flops(self.model.get_model())
        self.flops_compression_ratio = self.flops_prune / self.flops
        self.params_prune = get_parameters(self.model.get_model())
        self.params_compression_ratio = self.params_prune / self.params

        if self.args.summary and not self.args.test_only:
            log_dir = os.path.join(self.args.dir_save, self.args.save)
            self.writer = SummaryWriter(log_dir, comment='converging')

        epochs_path = os.path.join(self.ckp.dir, 'epochs.pt')
        if os.path.exists(epochs_path):
            self.epochs_searching = torch.load(epochs_path)
    def __init__(self,
                 args,
                 loader,
                 my_model,
                 my_loss,
                 ckp=None,
                 writer=None,
                 converging=False,
                 model_teacher=None):
        """
        Build a classification trainer, optionally distilling from a teacher.

        The input resolution is derived from ``args.data_train``: 32x32 for
        CIFAR, 64x64 for Tiny_ImageNet, and 224x224 otherwise. The FLOPs and
        parameter count of the unpruned network are recorded and logged. The
        optimizer and scheduler are created for the current stage; the
        scheduler uses the first '+'-separated field of ``args.decay``.

        :param args: parsed command-line options.
        :param loader: object providing ``loader_train`` and ``loader_test``.
        :param my_model: model wrapper exposing ``get_model()``.
        :param my_loss: training criterion.
        :param ckp: checkpoint/logging helper. ``ckp.write_log`` is called here,
            so it must be provided despite the ``None`` default.
        :param writer: optional TensorBoard writer.
        :param converging: True when the trainer starts in the converging stage.
        :param model_teacher: teacher network used for knowledge distillation.
        """
        self.args, self.ckp = args, ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model, self.model_teacher = my_model, model_teacher
        self.loss, self.writer = my_loss, writer
        # The feature-map MSE criterion is only needed for intermediate 'kd' distillation.
        self.loss_mse = nn.MSELoss() if args.distillation_inter == 'kd' else None

        data_name = args.data_train
        if 'CIFAR' in data_name:
            side = 32
        elif 'Tiny_ImageNet' in data_name:
            side = 64
        else:
            side = 224
        self.input_dim = (3, side, side)
        set_output_dimension(self.model.get_model(), self.input_dim)

        # Nothing is pruned yet, so the "pruned" statistics start equal to the full ones.
        self.flops = get_flops(self.model.get_model())
        self.flops_prune = self.flops
        self.flops_compression_ratio = self.flops_prune / self.flops
        self.params = get_parameters(self.model.get_model())
        self.params_prune = self.params
        self.params_compression_ratio = self.params_prune / self.params
        self.flops_ratio_log, self.params_ratio_log = [], []
        self.converging = converging

        complexity_msg = (
            '\nThe computation complexity and number of parameters of the current network is as follows.'
            '\nFlops: {:.4f} [G]\tParams {:.2f} [k]').format(self.flops / 10.**9, self.params / 10.**3)
        self.ckp.write_log(complexity_msg)
        # Cross-check the FLOPs against an independent counter.
        self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
        crosscheck_msg = (
            'Flops: {:.4f} [G] calculated by the original counter. \nMake sure that the two calculated '
            'Flops are the same.\n').format(self.flops_another / 10.**9)
        self.ckp.write_log(crosscheck_msg)

        self.optimizer = utility.make_optimizer_dhp(args, self.model, ckp=ckp, converging=converging)
        self.scheduler = utility.make_scheduler_dhp(args, self.optimizer, args.decay.split('+')[0],
                                                    converging=converging)
        self.device = torch.device('cpu' if args.cpu else 'cuda')

        self.inq_steps = args.inq_steps if 'INQ' in args.model else None
示例#5
0
    def __init__(self,
                 args,
                 loader,
                 my_model,
                 my_loss,
                 ckp,
                 writer=None,
                 converging=False):
        """
        Build a trainer for image-restoration / super-resolution networks.

        Creates the optimizer and scheduler; the scheduler uses the first
        '+'-separated field of ``args.lr_decay_step``. It also records the FLOPs
        and parameter count of the unpruned network and logs them through
        ``ckp``.

        :param args: parsed command-line options.
        :param loader: object providing ``loader_train`` and ``loader_test``.
        :param my_model: model wrapper exposing ``get_model()``.
        :param my_loss: training criterion.
        :param ckp: checkpoint/logging helper used for ``write_log``.
        :param writer: optional TensorBoard writer.
        :param converging: True when the trainer starts in the converging stage.
        """
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model, self.loss, self.writer = my_model, my_loss, writer

        self.optimizer = utility.make_optimizer_dhp(args, self.model, ckp, converging=converging)
        first_decay_step = int(args.lr_decay_step.split('+')[0])
        self.scheduler = utility.make_scheduler_dhp(args, self.optimizer, first_decay_step,
                                                    converging=converging)

        # UNet / DnCNN variants take a single input channel; everything else takes 3.
        model_name = self.args.model.lower()
        in_channels = 1 if ('unet' in model_name or 'dncnn' in model_name) else 3
        self.input_dim = (in_channels, args.input_dim, args.input_dim)
        set_output_dimension(self.model.get_model(), self.input_dim)

        # Nothing is pruned yet, so the "pruned" statistics start equal to the full ones.
        self.flops = get_flops(self.model.get_model())
        self.flops_prune = self.flops
        self.flops_compression_ratio = self.flops_prune / self.flops
        self.params = get_parameters(self.model.get_model())
        self.params_prune = self.params
        self.params_compression_ratio = self.params_prune / self.params
        self.flops_ratio_log, self.params_ratio_log = [], []
        self.converging = converging

        complexity_msg = (
            '\nThe computation complexity and number of parameters of the current network is as follows.'
            '\nFlops: {:.4f} [G]\tParams {:.2f} [k]').format(self.flops / 10.**9, self.params / 10.**3)
        self.ckp.write_log(complexity_msg)
        # Cross-check the FLOPs against an independent counter.
        self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
        crosscheck_msg = (
            'Flops: {:.4f} [G] calculated by the original counter. \nMake sure that the two calculated '
            'Flops are the same.\n').format(self.flops_another / 10.**9)
        self.ckp.write_log(crosscheck_msg)

        # Large initial value so the first epoch's skip threshold accepts every batch.
        self.error_last = 1e8
    def __init__(self, args, loader, my_model, my_loss, ckp, writer=None):
        """
        Build a plain classification trainer.

        The input resolution is derived from ``args.data_train``: 32x32 for
        CIFAR, 64x64 for Tiny_ImageNet, and 224x224 otherwise. The FLOPs and
        parameter count are logged through ``ckp``. The optimizer and scheduler
        come from ``utility.make_optimizer`` / ``utility.make_scheduler``.

        :param args: parsed command-line options.
        :param loader: object providing ``loader_train`` and ``loader_test``.
        :param my_model: model wrapper exposing ``get_model()``.
        :param my_loss: training criterion.
        :param ckp: checkpoint/logging helper used for ``write_log``.
        :param writer: optional TensorBoard writer.
        """
        self.args, self.ckp = args, ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model, self.loss, self.writer = my_model, my_loss, writer

        data_name = args.data_train
        if 'CIFAR' in data_name:
            side = 32
        elif 'Tiny_ImageNet' in data_name:
            side = 64
        else:
            side = 224
        self.input_dim = (3, side, side)
        set_output_dimension(self.model.get_model(), self.input_dim)

        self.flops = get_flops(self.model.get_model())
        self.params = get_parameters(self.model.get_model())
        complexity_msg = (
            '\nThe computation complexity and number of parameters of the current network is as follows.'
            '\nFlops: {:.4f} [G]\tParams {:.2f} [k]').format(self.flops / 10.**9, self.params / 10.**3)
        self.ckp.write_log(complexity_msg)
        # Cross-check the FLOPs against an independent counter.
        self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
        crosscheck_msg = (
            'Flops: {:.4f} [G] calculated by the original counter. \nMake sure that the two calculated '
            'Flops are the same.\n').format(self.flops_another / 10.**9)
        self.ckp.write_log(crosscheck_msg)

        self.optimizer = utility.make_optimizer(args, self.model, ckp=ckp)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.device = torch.device('cpu' if args.cpu else 'cuda')

        self.inq_steps = args.inq_steps if 'INQ' in args.model else None
示例#7
0
    # Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
    # ==================================================================================================================
    # Switch the trainer to the converging stage (pruning the searched model unless already converging / test-only).
    t.reset_after_optimization()
    if args.print_model:
        # Show the pruned architecture on stdout and in the log file.
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)

    # ==================================================================================================================
    # Step 4: Continue the training / Testing -> continue to train the pruned model to have a higher accuracy.
    # ==================================================================================================================
    # Alternate one training epoch and one evaluation until the trainer reports termination.
    while not t.terminate():
        t.train()
        t.test()

    # Report the complexity of the final network.
    # NOTE(review): this reads `model` while the code above uses `t.model`; presumably the same wrapper — TODO confirm.
    set_output_dimension(model.get_model(), t.input_dim)
    flops = get_flops(model.get_model())
    params = get_parameters(model.get_model())
    print(
        '\nThe computation complexity and number of parameters of the current network is as follows.'
        '\nFlops: {:.4f} [G]\tParams {:.2f} [k]\n'.format(
            flops / 10.**9, params / 10.**3))

    # Flush and close the TensorBoard writer, if summaries were enabled.
    if args.summary:
        t.writer.close()
    if args.print_model:
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)
    # for m in t.model.parameters():
    #     print(m.shape)
    checkpoint.done()
示例#8
0
    def train(self):
        """
        Run one training epoch of the image-restoration network.

        A batch is skipped unless its loss is below ``args.skip_threshold``
        times the error of the previous epoch.

        While searching (``self.converging`` False), each accepted step is
        followed by the proximal operator. Every ``compression_check_frequency``
        batches, the potential pruned model is materialized to refresh the
        FLOPs/parameter ratios, and the epoch ends early once
        ``self.terminate()`` is True.

        The scheduler is stepped once at the end of the epoch.
        """
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        learning_rate = self.scheduler.get_lr()[0]
        idx_scale = self.args.scale
        if self.converging:
            stage = 'Finetuning Stage (Searching Epoch {})'.format(self.epochs_searching)
        else:
            stage = 'Searching Stage'
        self.ckp.write_log('\n[Epoch {}]\tLearning rate: {:.2e}\t{}'.format(
            epoch, Decimal(learning_rate), stage))
        self.loss.start_log()
        self.model.train()
        timer_data = utility.timer()
        timer_model = utility.timer()

        def record_compression():
            # Set the channels of the potential pruned model, then measure FLOPs and parameters.
            self.model.get_model().set_parameters()
            self.flops_prune = get_flops(self.model.get_model())
            self.flops_compression_ratio = self.flops_prune / self.flops
            self.params_prune = get_parameters(self.model.get_model())
            self.params_compression_ratio = self.params_prune / self.params
            self.flops_ratio_log.append(self.flops_compression_ratio)
            self.params_ratio_log.append(self.params_compression_ratio)

        for batch, (lr, hr, _) in enumerate(self.loader_train):
            step = batch + 1
            lr, hr = self.prepare([lr, hr])

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            loss = self.loss(self.model(idx_scale, lr), hr)
            skip_limit = self.args.skip_threshold * self.error_last

            if not loss.item() < skip_limit:
                print('Skip this batch {}! (Loss: {}) (Threshold: {})'.format(
                    step, loss.item(), skip_limit))
            else:
                loss.backward()  # Adam
                self.optimizer.step()
                if not self.converging:
                    # proximal operator
                    self.model.get_model().proximal_operator(learning_rate)
                    if step % self.args.compression_check_frequency == 0:
                        record_compression()
                        if self.terminate():
                            break
                    if step % 1000 == 0:
                        self.model.get_model().latent_vector_distribution(
                            epoch, step, self.ckp.dir)
                        self.model.get_model().per_layer_compression_ratio(
                            epoch, step, self.ckp.dir)

            timer_model.hold()

            if step % self.args.print_every == 0:
                self.ckp.write_log(
                    '[{}/{}]\t{}\t{:.3f}+{:.3f}s'
                    '\tFlops Ratio: {:.2f}% = {:.4f} G / {:.4f} G'
                    '\tParams Ratio: {:.2f}% = {:.2f} k / {:.2f} k'.format(
                        step * self.args.batch_size,
                        len(self.loader_train.dataset),
                        self.loss.display_loss(batch),
                        timer_model.release(),
                        timer_data.release(),
                        self.flops_compression_ratio * 100,
                        self.flops_prune / 10.**9,
                        self.flops / 10.**9,
                        self.params_compression_ratio * 100,
                        self.params_prune / 10.**3,
                        self.params / 10.**3))
            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.scheduler.step()
    def train(self):
        """
        Run one training epoch of the classification network.

        For each batch:

        * On ImageNet (or with ``efficientnet_hh``) during the searching stage,
          only the first ``1/divider`` of the batch is used. ``divider`` is 4
          for ``ResNet_ImageNet_HH``/``RegNet_ImageNet_HH`` and 2 otherwise.
        * The loss comes from ``_compute_distillation_loss``. It is either the
          plain task loss, or the task loss combined with final-output and/or
          intermediate-feature distillation against ``self.model_teacher``.
        * During the searching stage with ``args.use_prox``, the proximal
          operator is applied after the SGD step. Every
          ``compression_check_frequency`` batches the FLOPs/parameter ratios are
          refreshed, and the epoch stops early once ``self.terminate()`` is True.
        * Progress is logged every ``print_every`` batches, and TensorBoard
          summaries are written when ``args.summary`` is set.
        """
        epoch, lr = self.start_epoch()
        self.model.begin(
            epoch, self.ckp
        )  #TODO: investigate why not using self.model.train() directly
        self.loss.start_log()
        timer_data, timer_model = utility.timer(), utility.timer()
        n_samples = 0

        for batch, (img, label) in enumerate(self.loader_train):
            # In the searching stage only part of each batch is used on these
            # settings (presumably to bound memory/time) — TODO confirm intent.
            if (self.args.data_train == 'ImageNet' or self.args.model.lower()
                    == 'efficientnet_hh') and not self.converging:
                if self.args.model in ('ResNet_ImageNet_HH', 'RegNet_ImageNet_HH'):
                    divider = 4
                else:
                    divider = 2
                print('Divider is {}'.format(divider))
                batch_size = img.shape[0] // divider
                img = img[:batch_size]
                label = label[:batch_size]
            img, label = self.prepare(img, label)
            n_samples += img.size(0)

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            prediction = self.model(img)
            loss, loss_distill_final, loss_distill_inter = \
                self._compute_distillation_loss(img, label, prediction)

            # SGD
            loss.backward()
            self.optimizer.step()
            if not self.converging and self.args.use_prox:
                # proximal operator
                self.model.get_model().proximal_operator(lr)
                if (batch + 1) % self.args.compression_check_frequency == 0:
                    self.model.get_model().set_parameters()
                    self.flops_prune = get_flops(self.model.get_model())
                    self.flops_compression_ratio = self.flops_prune / self.flops
                    self.params_prune = get_parameters(self.model.get_model())
                    self.params_compression_ratio = self.params_prune / self.params
                    self.flops_ratio_log.append(self.flops_compression_ratio)
                    self.params_ratio_log.append(self.params_compression_ratio)
                    if self.terminate():
                        break
                if (batch + 1) % 300 == 0:
                    self.model.get_model().latent_vector_distribution(
                        epoch, batch + 1, self.ckp.dir)
                    self.model.get_model().per_layer_compression_ratio(
                        epoch, batch + 1, self.ckp.dir)

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                s = '{}/{} ({:.0f}%)\tNLL: {:.3f} Top1: {:.2f} / Top5: {:.2f}\t'.format(
                    n_samples, len(self.loader_train.dataset),
                    100.0 * n_samples / len(self.loader_train.dataset),
                    *(self.loss.log_train[-1, :] / n_samples))
                if self.converging or self.args.distillation_stage == 's':
                    # Report only the distillation terms actually computed for this
                    # batch; the plain-loss branch computes neither of them.
                    if self.args.distillation_final and loss_distill_final is not None:
                        s += 'DFinal: {:.3f} '.format(loss_distill_final)
                    if self.args.distillation_inter and loss_distill_inter is not None:
                        s += 'DInter: {:.3f}'.format(loss_distill_inter)
                    if self.args.distillation_final or self.args.distillation_inter:
                        s += '\t'
                s += 'Time: {:.1f}+{:.1f}s\t'.format(timer_model.release(),
                                                     timer_data.release())
                if hasattr(self, 'flops_compression_ratio') and hasattr(
                        self, 'params_compression_ratio'):
                    s += ('Flops: {:.2f}% = {:.4f} [G] / {:.4f} [G]\t'
                          'Params: {:.2f}% = {:.2f} [k] / {:.2f} [k]').format(
                        self.flops_compression_ratio * 100, self.flops_prune / 10. ** 9, self.flops / 10. ** 9,
                        self.params_compression_ratio * 100, self.params_prune / 10. ** 3, self.params / 10. ** 3)

                self.ckp.write_log(s)

            if self.args.summary:
                global_step = 1000 * (epoch - 1) + batch
                if (batch + 1) % 50 == 0:
                    for name, param in self.model.named_parameters():
                        if 'features' in name and 'weight' in name:
                            self.writer.add_scalar(
                                'data/' + name,
                                param.clone().cpu().data.abs().mean().numpy(),
                                global_step)
                            if param.grad is not None:
                                self.writer.add_scalar(
                                    'data/' + name + '_grad',
                                    param.grad.clone().cpu().data.abs().mean().numpy(),
                                    global_step)
                if (batch + 1) == 500:
                    for name, param in self.model.named_parameters():
                        if 'features' in name and 'weight' in name:
                            self.writer.add_histogram(
                                name,
                                param.clone().cpu().data.numpy(),
                                global_step)
                            if param.grad is not None:
                                self.writer.add_histogram(
                                    name + '_grad',
                                    param.grad.clone().cpu().data.numpy(),
                                    global_step)
            timer_data.tic()
            # In searching epoch `epochs_grad`, stop after the second batch
            # (batch index 1) — TODO confirm the intent of this early exit.
            if not self.converging and epoch == self.args.epochs_grad and batch == 1:
                break
        self.model.log(self.ckp)  # TODO: why this is used?
        self.loss.end_log(len(self.loader_train.dataset))

    def _compute_distillation_loss(self, img, label, prediction):
        """
        Compute the training loss for one batch, adding distillation terms when enabled.

        The plain task loss is used in two cases: the searching stage with
        ``distillation_stage == 'c'``, and the converging stage without
        ``distillation_final``. Otherwise the teacher is evaluated without
        gradients, and the configured losses are added on top of the task loss:

        * final-output distillation (``distillation_final``: 'kd' or 'sp');
        * intermediate-feature distillation (``distillation_inter``: 'kd' or 'sp').

        :param img: prepared input batch (also fed to the teacher).
        :param label: target labels.
        :param prediction: student output for ``img``.
        :return: ``(loss, loss_distill_final, loss_distill_inter)``; a
            distillation term is ``None`` when it was not computed.
        """
        loss_distill_final = None
        loss_distill_inter = None
        if (not self.converging and self.args.distillation_stage == 'c') or \
                (self.converging and not self.args.distillation_final):
            loss, _ = self.loss(prediction, label)
            return loss, loss_distill_final, loss_distill_inter

        with torch.no_grad():
            prediction_teacher = self.model_teacher(img)
        if not self.args.distillation_inter:
            # Wrap the outputs so index [0] addresses the final output in both cases;
            # with intermediate distillation, index [1] holds the feature list.
            prediction = [prediction]
            prediction_teacher = [prediction_teacher]
        loss, _ = self.loss(prediction[0], label)

        # Final-output distillation is selected by `distillation_final`.
        if self.args.distillation_final == 'kd':
            loss_distill_final = distillation(prediction[0],
                                              prediction_teacher[0],
                                              T=4)
            loss = 0.4 * loss_distill_final + 0.6 * loss
        elif self.args.distillation_final == 'sp':
            loss_distill_final = similarity_preserving(
                prediction[0], prediction_teacher[0]) * 3000
            loss = loss_distill_final + loss

        # Intermediate-feature distillation is selected by `distillation_inter`.
        if self.args.distillation_inter == 'kd':
            loss_distill_inter = 0
            for p, pt in zip(prediction[1], prediction_teacher[1]):
                loss_distill_inter += self.loss_mse(p, pt)
            loss_distill_inter = loss_distill_inter / len(
                prediction[1]) * self.args.distill_beta
            loss = loss_distill_inter + loss
        elif self.args.distillation_inter == 'sp':
            loss_distill_inter = 0
            for p, pt in zip(prediction[1], prediction_teacher[1]):
                loss_distill_inter += similarity_preserving(p, pt)
            loss_distill_inter = loss_distill_inter / len(
                prediction[1]) * 3000 * self.args.distill_beta
            loss = loss_distill_inter + loss
        # else: self.args.distillation_inter == '', no intermediate term
        return loss, loss_distill_final, loss_distill_inter
示例#10
0
    def train(self):
        """
        Run one training epoch of the classification network.

        While searching (``self.converging`` False), each SGD step is followed
        by the proximal operator. Every ``compression_check_frequency`` batches,
        the potential pruned model is materialized to refresh the
        FLOPs/parameter ratios, and the epoch ends as soon as
        ``self.terminate()`` is True.

        Progress is logged every ``print_every`` batches, and TensorBoard
        summaries are written when ``args.summary`` is set.
        """
        epoch, lr = self.start_epoch()
        self.model.begin(epoch, self.ckp)  # TODO: investigate why not using self.model.train() directly
        self.loss.start_log()
        timer_data = utility.timer()
        timer_model = utility.timer()
        n_samples = 0

        def record_compression():
            # Set the channels of the potential pruned model, then measure FLOPs and parameters.
            self.model.get_model().set_parameters()
            self.flops_prune = get_flops(self.model.get_model())
            self.flops_compression_ratio = self.flops_prune / self.flops
            self.params_prune = get_parameters(self.model.get_model())
            self.params_compression_ratio = self.params_prune / self.params
            self.flops_ratio_log.append(self.flops_compression_ratio)
            self.params_ratio_log.append(self.params_compression_ratio)

        def write_summaries(step, global_step):
            # Mean-|w| scalars every 50 batches and a one-off histogram at batch 500.
            if step % 50 == 0:
                for name, param in self.model.named_parameters():
                    if 'features' in name and 'weight' in name:
                        self.writer.add_scalar('data/' + name, param.clone().cpu().data.abs().mean().numpy(),
                                               global_step)
                        if param.grad is not None:
                            self.writer.add_scalar('data/' + name + '_grad',
                                                   param.grad.clone().cpu().data.abs().mean().numpy(),
                                                   global_step)
            if step == 500:
                for name, param in self.model.named_parameters():
                    if 'features' in name and 'weight' in name:
                        self.writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step)
                        if param.grad is not None:
                            self.writer.add_histogram(name + '_grad', param.grad.clone().cpu().data.numpy(),
                                                      global_step)

        log_template = ('{}/{} ({:.0f}%)\t'
                        'NLL: {:.3f}\tTop1: {:.2f} / Top5: {:.2f}\t'
                        'Time: {:.1f}+{:.1f}s\t'
                        'Flops Ratio: {:.2f}% = {:.4f} [G] / {:.4f} [G]\t'
                        'Params Ratio: {:.2f}% = {:.2f} [k] / {:.2f} [k]')

        for batch, (img, label) in enumerate(self.loader_train):
            step = batch + 1
            img, label = self.prepare(img, label)
            n_samples += img.size(0)

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            loss, _ = self.loss(self.model(img), label)
            loss.backward()  # SGD
            self.optimizer.step()

            if not self.converging:
                # proximal operator
                self.model.get_model().proximal_operator(lr)
                if step % self.args.compression_check_frequency == 0:
                    record_compression()
                if step % 300 == 0:
                    self.model.get_model().latent_vector_distribution(epoch, step, self.ckp.dir)
                    self.model.get_model().per_layer_compression_ratio(epoch, step, self.ckp.dir)

            timer_model.hold()

            if step % self.args.print_every == 0:
                self.ckp.write_log(log_template.format(
                    n_samples, len(self.loader_train.dataset), 100.0 * n_samples / len(self.loader_train.dataset),
                    *(self.loss.log_train[-1, :] / n_samples),
                    timer_model.release(), timer_data.release(),
                    self.flops_compression_ratio * 100, self.flops_prune / 10. ** 9, self.flops / 10. ** 9,
                    self.params_compression_ratio * 100, self.params_prune / 10. ** 3, self.params / 10. ** 3))
            if not self.converging and self.terminate():
                break

            if self.args.summary:
                write_summaries(step, 1000 * (epoch - 1) + batch)

            timer_data.tic()
        self.model.log(self.ckp)  # TODO: why this is used?
        self.loss.end_log(len(self.loader_train.dataset))
示例#11
0
    # Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
    # ==================================================================================================================

    # Switch the trainer to the converging stage (pruning the searched model unless already converging / test-only).
    t.reset_after_searching()
    if args.print_model:
        # Show the pruned architecture on stdout and in the log file.
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)

    # ==================================================================================================================
    # Step 4: Finetuning/Testing -> finetune the pruned model to have a higher accuracy.
    # ==================================================================================================================

    # Alternate one training epoch and one evaluation until the trainer reports termination.
    while not t.terminate():
        t.train()
        t.test()

    # Report the complexity of the final network.
    # NOTE(review): this reads `network_model` while the code above uses `t.model`; presumably the same wrapper —
    # TODO confirm.
    set_output_dimension(network_model.get_model(), t.input_dim)
    flops = get_flops(network_model.get_model())
    params = get_parameters(network_model.get_model())
    print(
        '\nThe computation complexity and number of parameters of the current network is as follows.'
        '\nFlops: {:.4f} [G]\tParams {:.2f} [k]\n'.format(
            flops / 10.**9, params / 10.**3))

    # Flush and close the TensorBoard writer, if summaries were enabled.
    if args.summary:
        t.writer.close()
    if args.print_model:
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)
    checkpoint.done()