Example #1
    def __init__(self,
                 dataset,
                 class_num,
                 image_mean,
                 image_std,
                 network,
                 multi_scales,
                 is_flip,
                 devices,
                 verbose=False,
                 save_path=None,
                 show_image=False):
        self.dataset = dataset
        self.ndata = self.dataset.get_length()
        self.class_num = class_num
        self.image_mean = image_mean
        self.image_std = image_std
        self.multi_scales = multi_scales
        self.is_flip = is_flip
        self.network = network
        self.devices = devices

        self.context = mp.get_context('spawn')
        self.val_func = None
        self.results_queue = self.context.Queue(self.ndata)

        self.verbose = verbose
        self.save_path = save_path
        if save_path is not None:
            ensure_dir(save_path)
        self.show_image = show_image
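
A note on the helper used throughout this page: every example calls ensure_dir before writing checkpoints, logs, or predictions, but its implementation is not shown here. A minimal sketch, assuming it only needs to create the directory (and any missing parents) when it does not already exist:

import os

def ensure_dir(path):
    # Create the directory (and any missing parents) if it does not exist yet.
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
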
Example #2
 def save_and_link_checkpoint(self, snapshot_dir):
     if self.local_rank > 0:
         return
     ensure_dir(snapshot_dir)
     current_iter_checkpoint = osp.join(
         snapshot_dir, 'iter-{}.pth'.format(self.state.iteration))
     self.save_checkpoint(current_iter_checkpoint)
Example #3
    def __init__(self, dataset, class_num, image_mean, image_std, network,
                 multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None,
                 verbose=False, save_path=None, show_image=False, show_prediction=False):
        self.dataset = dataset
        self.ndata = self.dataset.get_length()
        self.class_num = class_num
        self.image_mean = image_mean
        self.image_std = image_std
        self.multi_scales = multi_scales
        self.is_flip = is_flip
        self.network = network
        self.devices = devices
        if isinstance(self.devices, int):
            self.devices = [self.devices]
        self.out_idx = out_idx
        self.threds = threds
        self.config = config
        self.logger = logger

        self.context = mp.get_context('spawn')
        self.val_func = None
        self.results_queue = self.context.Queue(self.ndata)

        self.verbose = verbose
        self.save_path = save_path
        if save_path is not None:
            ensure_dir(save_path)
        self.show_image = show_image
        self.show_prediction = show_prediction
Example #4
 def save_latest_ckpt(self, snapshot_dir):
     if self.local_rank > 0:
         return
     ensure_dir(snapshot_dir)
     current_iter_checkpoint = osp.join(
         snapshot_dir, 'latest.pth')
     self.save_checkpoint(current_iter_checkpoint)
Example #5
    def func_per_iteration(self, data, device):
        img = data['data']
        label = data['label']
        hha = data['hha_img']
        tsdf = data['tsdf']
        label_weight = data['label_weight']
        depth_mapping_3d = data['depth_mapping_3d']

        name = data['fn']
        sketch_gt = data['sketch_gt']
        pred, pred_sketch = self.eval_ssc(img, hha, tsdf, depth_mapping_3d,
                                          sketch_gt, device)

        results_dict = {
            'pred': pred,
            'label': label,
            'label_weight': label_weight,
            'name': name,
            'mapping': depth_mapping_3d
        }

        if self.save_path is not None:
            ensure_dir(self.save_path)
            ensure_dir(self.save_path + '_sketch')
            fn = name + '.npy'
            np.save(os.path.join(self.save_path, fn), pred)
            np.save(os.path.join(self.save_path + '_sketch', fn), pred_sketch)
            logger.info('Saved the pred npy ' + fn)

        return results_dict
Example #6
 def save_and_link_checkpoint(self, snapshot_dir, log_dir, log_dir_link):
     ensure_dir(snapshot_dir)
     if not osp.exists(log_dir_link):
         link_file(log_dir, log_dir_link)
     current_epoch_checkpoint = osp.join(
         snapshot_dir, 'epoch-{}.pth'.format(self.state.epoch))
     self.save_checkpoint(current_epoch_checkpoint)
     last_epoch_checkpoint = osp.join(snapshot_dir, 'epoch-last.pth')
     link_file(current_epoch_checkpoint, last_epoch_checkpoint)
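
Example #6 keeps epoch-last.pth pointing at the newest checkpoint via link_file, whose implementation is not part of this page. A minimal sketch, assuming it simply (re)creates a symbolic link from source to target:

import os

def link_file(source, target):
    # Replace whatever currently sits at target with a symlink to source.
    if os.path.islink(target) or os.path.exists(target):
        os.remove(target)
    os.symlink(source, target)
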
Example #7
def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
    logger = logging.getLogger()
    logger.setLevel(_default_level)
    del logger.handlers[:]

    if log_dir and log_file:
        pyt_utils.ensure_dir(log_dir)
        LogFormatter.log_fout = True
        file_handler = logging.FileHandler(log_file, mode='a')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
        logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
    stream_handler.setLevel(0)
    logger.addHandler(stream_handler)
    return logger
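
A brief usage sketch for the get_logger above, assuming LogFormatter and _default_level are defined in the same module; the paths are placeholders:

logger = get_logger(log_dir='./log', log_file='./log/train.log')
logger.info('logger ready: messages go to stdout and ./log/train.log')
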
Example #8
    def func_per_iteration(self, data, device):
        img = data['data']
        label = data['label']
        hha = data['hha_img']
        name = data['fn']
        pred = self.sliding_eval_rgbdepth(img, hha, config.eval_crop_size,
                                          config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            ensure_dir(self.save_path)
            ensure_dir(self.save_path + '_color')

            fn = name + '.png'

            # save colored result
            result_img = Image.fromarray(pred.astype(np.uint8), mode='P')
            class_colors = get_class_colors()
            palette_list = list(np.array(class_colors).flat)
            if len(palette_list) < 768:
                palette_list += [0] * (768 - len(palette_list))
            result_img.putpalette(palette_list)
            result_img.save(os.path.join(self.save_path + '_color', fn))

            # save raw result
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean, label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)

        return results_dict
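
The palette handling above pads the color list to 768 values because a PIL P-mode image expects 256 RGB triplets. A self-contained sketch of the same save pattern with a hypothetical two-class colormap:

import numpy as np
from PIL import Image

pred = np.zeros((4, 4), dtype=np.uint8)   # dummy label map
pred[2:, 2:] = 1                          # second class in one corner
class_colors = [(0, 0, 0), (255, 0, 0)]   # hypothetical colormap

palette = list(np.array(class_colors).flat)
palette += [0] * (768 - len(palette))     # pad to 256 RGB triplets
result_img = Image.fromarray(pred, mode='P')
result_img.putpalette(palette)
result_img.save('pred_color.png')
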
Example #9
def csgd_train_main(local_rank,
                    cfg: BaseConfigByEpoch,
                    target_deps,
                    succeeding_strategy,
                    pacesetter_dict,
                    centri_strength,
                    pruned_weights,
                    net=None,
                    train_dataloader=None,
                    val_dataloader=None,
                    show_variables=False,
                    convbuilder=None,
                    init_hdf5=None,
                    no_l2_keywords='depth',
                    use_nesterov=False,
                    load_weights_keyword=None,
                    keyword_to_lr_mult=None,
                    auto_continue=False,
                    save_hdf5_epochs=10000):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')

    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        if no_l2_keywords is None:
            no_l2_keywords = []
        if type(no_l2_keywords) is not list:
            no_l2_keywords = [no_l2_keywords]
        # For each pruned target parameter, cancel its weight decay in the optimizer, because the decay is later encoded in the decay matrix
        conv_idx = 0
        for k, v in model.named_parameters():
            if v.dim() != 4:
                continue
            print('prune layer {} from {} to {}'.format(conv_idx,
                                                        cfg.deps[conv_idx],
                                                        target_deps[conv_idx]))
            if target_deps[conv_idx] < cfg.deps[conv_idx]:
                no_l2_keywords.append(k.replace(KERNEL_KEYWORD, 'conv'))
                no_l2_keywords.append(k.replace(KERNEL_KEYWORD, 'bn'))
            conv_idx += 1
        print('no l2: ', no_l2_keywords)
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_hdf5(init_hdf5,
                             load_weights_keyword=load_weights_keyword)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        #   ===================================== prepare the clusters and matrices for C-SGD ==========
        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list()

        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path,
                                            allow_pickle=True).item()
        else:
            if local_rank == 0:
                layer_idx_to_clusters = get_layer_idx_to_clusters(
                    kernel_namedvalue_list=kernel_namedvalue_list,
                    target_deps=target_deps,
                    pacesetter_dict=pacesetter_dict)
                if pacesetter_dict is not None:
                    for follower_idx, pacesetter_idx in pacesetter_dict.items(
                    ):
                        if pacesetter_idx in layer_idx_to_clusters:
                            layer_idx_to_clusters[
                                follower_idx] = layer_idx_to_clusters[
                                    pacesetter_idx]

                np.save(clusters_save_path, layer_idx_to_clusters)
            else:
                while not os.path.exists(clusters_save_path):
                    time.sleep(10)
                    print('sleep, waiting for process 0 to calculate clusters')
                layer_idx_to_clusters = np.load(clusters_save_path,
                                                allow_pickle=True).item()

        param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
            deps=cfg.deps,
            layer_idx_to_clusters=layer_idx_to_clusters,
            kernel_namedvalue_list=kernel_namedvalue_list)
        add_vecs_to_merge_mat_dicts(param_name_to_merge_matrix)
        param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
            deps=cfg.deps,
            layer_idx_to_clusters=layer_idx_to_clusters,
            kernel_namedvalue_list=kernel_namedvalue_list,
            weight_decay=cfg.weight_decay,
            weight_decay_bias=cfg.weight_decay_bias,
            centri_strength=centri_strength)
        print(param_name_to_decay_matrix.keys())
        print(param_name_to_merge_matrix.keys())

        conv_idx = 0
        param_to_clusters = {}
        for k, v in model.named_parameters():
            if v.dim() != 4:
                continue
            if conv_idx in layer_idx_to_clusters:
                for clsts in layer_idx_to_clusters[conv_idx]:
                    if len(clsts) > 1:
                        param_to_clusters[v] = layer_idx_to_clusters[conv_idx]
                        break
            conv_idx += 1
        #   ============================================================================================

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    param_name_to_merge_matrix=param_name_to_merge_matrix,
                    param_name_to_decay_matrix=param_name_to_decay_matrix)
                train_net_time_end = time.time()

                if TRAIN_SPEED_START * max_iters < iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)
                    deviation_sum = 0
                    for param, clusters in param_to_clusters.items():
                        pvalue = param.detach().cpu().numpy()
                        for cl in clusters:
                            if len(cl) == 1:
                                continue
                            selected = pvalue[cl, :, :, :]
                            mean_kernel = np.mean(selected,
                                                  axis=0,
                                                  keepdims=True)
                            diff = selected - mean_kernel
                            deviation_sum += np.sum(diff**2)
                    tb_writer.add_scalars('deviation_sum',
                                          {'Train': deviation_sum}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))

    if local_rank == 0:
        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
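
Examples #9, #11, and #14 pass param_name_to_merge_matrix and param_name_to_decay_matrix into train_one_step, which is not shown on this page. The sketch below illustrates the centripetal-SGD style gradient rewrite these matrices suggest (averaging gradients within a filter cluster and decaying each kernel toward its cluster mean); the names and shapes are assumptions, not the repository's actual implementation:

import torch

def apply_csgd_rule(model, merge_mats, decay_mats):
    # Rewrite each matched gradient as merge @ grad + decay @ weight,
    # flattening the kernel to an (out_channels, -1) matrix first.
    for name, param in model.named_parameters():
        key = name.replace('module.', '')
        if key not in merge_mats or param.grad is None:
            continue
        out_c = param.size(0)
        g = param.grad.reshape(out_c, -1)
        w = param.detach().reshape(out_c, -1)
        new_g = merge_mats[key].matmul(g) + decay_mats[key].matmul(w)
        param.grad.copy_(new_g.reshape(param.size()))
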
Example #10
def ding_train(cfg:BaseConfigByEpoch, net=None, train_dataloader=None, val_dataloader=None, show_variables=False, convbuilder=None, beginning_msg=None,
               init_hdf5=None, no_l2_keywords=None, gradient_mask=None, use_nesterov=False):

    # LOCAL_RANK = 0
    #
    # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # is_distributed = num_gpus > 1
    #
    # if is_distributed:
    #     torch.cuda.set_device(LOCAL_RANK)
    #     torch.distributed.init_process_group(
    #         backend="nccl", init_method="env://"
    #     )
    #     synchronize()
    #
    # torch.backends.cudnn.benchmark = True

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine() as engine:

        is_main_process = (engine.world_rank == 0) #TODO correct?

        logger = engine.setup_log(
            name='train', log_dir=cfg.output_dir, file_name='log.txt')

        # -- typical model components: model, opt, scheduler, dataloader --#
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name, cfg.dataset_subset, cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name, 'val', batch_size=100)    #TODO 100?

        print('NOTE: Data prepared')
        print('NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'.format(cfg.global_batch_size, torch.cuda.device_count(), torch.cuda.memory_allocated()))

        # device = torch.device(cfg.device)
        # model.to(device)
        # model.cuda()

        if no_l2_keywords is None:
            no_l2_keywords = []
        optimizer = get_optimizer(cfg, model, no_l2_keywords=no_l2_keywords, use_nesterov=use_nesterov)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(
            scheduler=scheduler, model=model, optimizer=optimizer)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[engine.world_rank],
                broadcast_buffers=False, )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        # for k, v in model.named_parameters():
        #     if v.dim() in [2, 4]:
        #         torch.nn.init.xavier_normal_(v)
        #         print('init {} as xavier_normal'.format(k))
        #     if 'bias' in k and 'bn' not in k.lower():
        #         torch.nn.init.zeros_(v)
        #         print('init {} as zero'.format(k))

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights, is_restore=True)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)


        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        if beginning_msg:
            engine.log(beginning_msg)
        logger.info("\n\nStart training with pytorch version {}".format(torch.__version__))

        iteration = engine.state.iteration
        # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch

        engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        # summary(model=model, input_size=(224, 224) if cfg.dataset_name == 'imagenet' else (32, 32), batch_size=cfg.global_batch_size)

        recorded_train_time = 0
        recorded_train_examples = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            pbar = tqdm(range(iters_per_epoch))
            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)


            if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                model.eval()
                val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for val on ImageNet and CIFAR
                eval_dict, _ = run_eval(val_dataloader, val_iters, model, criterion, discrip_str, dataset_name=cfg.dataset_name)
                val_top1_value = eval_dict['top1'].item()
                val_top5_value = eval_dict['top5'].item()
                val_loss_value = eval_dict['loss'].item()
                for tag, value in zip(tb_tags, [val_top1_value, val_top5_value, val_loss_value]):
                    tb_writer.add_scalars(tag, {'Val': value}, iteration)
                engine.log('validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'.format(epoch, val_top1_value, val_top5_value, val_loss_value))
                model.train()

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(model, data, label, optimizer, criterion, if_accum_grad, gradient_mask_tensor=gradient_mask_tensor)
                train_net_time_end = time.time()

                if TRAIN_SPEED_START * max_iters < iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                if iteration % cfg.tb_iter_period == 0 and is_main_process:
                    for tag, value in zip(tb_tags, [acc.item(), acc5.item(), loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)


                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and is_main_process):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                iteration += 1
                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            if iteration >= max_iters:
                break
        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters), int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
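
In Example #10 the if_accum_grad flag, driven by cfg.grad_accum_iters, suggests that train_one_step accumulates gradients and only lets the optimizer step every few iterations. A minimal sketch of that pattern; all names here are hypothetical:

def train_step_with_accum(model, data, label, optimizer, criterion, if_accum_grad):
    # Backpropagate every step, but update weights only when if_accum_grad is False.
    pred = model(data)
    loss = criterion(pred, label)
    loss.backward()
    if not if_accum_grad:
        optimizer.step()
        optimizer.zero_grad()
    return loss
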
Example #11
def csgd_train_and_prune(cfg: BaseConfigByEpoch,
                         target_deps,
                         centri_strength,
                         pacesetter_dict,
                         succeeding_strategy,
                         pruned_weights,
                         extra_cfg,
                         net=None,
                         train_dataloader=None,
                         val_dataloader=None,
                         show_variables=False,
                         beginning_msg=None,
                         init_weights=None,
                         no_l2_keywords=None,
                         use_nesterov=False,
                         tensorflow_style_init=False,
                         iter=None):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')
    print("cluster save path:{}".format(clusters_save_path))
    config = extra_cfg

    with Engine() as engine:

        is_main_process = (engine.world_rank == 0)  #TODO correct?

        logger = engine.setup_log(name='train',
                                  log_dir=cfg.output_dir,
                                  file_name='log.txt')

        saveName = "%s-%s.yaml" % (config['note'], config['dataset'])
        modelName = config['modelName']
        os.environ['CUDA_VISIBLE_DEVICES'] = config['gpu_available']
        device_ids = range(config['gpu_num'])
        trainSet = GoProDataset(sharp_root=config['train_sharp'],
                                blur_root=config['train_blur'],
                                resize_size=config['resize_size'],
                                patch_size=config['crop_size'],
                                phase='train')
        testSet = GoProDataset(sharp_root=config['test_sharp'],
                               blur_root=config['test_blur'],
                               resize_size=config['resize_size'],
                               patch_size=config['crop_size'],
                               phase='test')

        train_loader = DataLoader(trainSet,
                                  batch_size=config['batchsize'],
                                  shuffle=True,
                                  num_workers=4,
                                  drop_last=True,
                                  pin_memory=True)
        test_loader = DataLoader(testSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 drop_last=False,
                                 pin_memory=True)

        print('NOTE: Data prepared')
        print(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(config['batchsize'], torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))

        model = net

        optimizer = get_optimizer(cfg, model, use_nesterov=use_nesterov)
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=config['step'],
                                             gamma=0.5)  # learning rates
        criterion = get_criterion(cfg).cuda()

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer,
                              cfg=cfg)

        model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

        # load weights from the last pruning iteration, or from the unpruned model
        if init_weights:
            engine.load_pth(init_weights)

        # for UNet, the last outconv will not be pruned
        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list(
            remove='out')

        # cluster filters
        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path,
                                            allow_pickle=True).item()
            print("cluster exist, load from {}".format(clusters_save_path))
        else:
            layer_idx_to_clusters = get_layer_idx_to_clusters(
                kernel_namedvalue_list=kernel_namedvalue_list,
                target_deps=target_deps,
                pacesetter_dict=pacesetter_dict)
            if pacesetter_dict is not None:
                for follower_idx, pacesetter_idx in pacesetter_dict.items():
                    if pacesetter_idx in layer_idx_to_clusters:
                        layer_idx_to_clusters[
                            follower_idx] = layer_idx_to_clusters[
                                pacesetter_idx]

            # print(layer_idx_to_clusters)

            np.save(clusters_save_path, layer_idx_to_clusters)

        csgd_save_file = os.path.join(cfg.output_dir, 'finish.pth')

        # if this pruning iteration already has a trained model, load it
        if os.path.exists(csgd_save_file):
            engine.load_pth(csgd_save_file)
        else:
            param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list)
            param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list,
                weight_decay=cfg.weight_decay,
                centri_strength=centri_strength)
            # if pacesetter_dict is not None:
            #     for follower_idx, pacesetter_idx in pacesetter_dict.items():
            #         follower_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         pacesetter_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         if pacesetter_kernel_name in param_name_to_merge_matrix:
            #             param_name_to_merge_matrix[follower_kernel_name] = param_name_to_merge_matrix[
            #                 pacesetter_kernel_name]
            #             param_name_to_decay_matrix[follower_kernel_name] = param_name_to_decay_matrix[
            #                 pacesetter_kernel_name]

            # add the bn and conv bias parameters (vectors) to the matrix dicts to enable the C-SGD update rule
            add_vecs_to_mat_dicts(param_name_to_merge_matrix)

            if show_variables:
                engine.show_variables()

            if beginning_msg:
                engine.log(beginning_msg)

            logger.info("\n\nStart training with pytorch version {}".format(
                torch.__version__))

            iteration = engine.state.iteration
            startEpoch = config['start_epoch']
            max_epochs = config['max_epochs']

            engine.save_pth(os.path.join(cfg.output_dir, 'init.pth'))

            viz = Visdom(env=saveName)
            bestPSNR = config['bestPSNR']

            itr = '' if iter is None else str(iter)
            for epoch in range(startEpoch, max_epochs):
                # eval
                if epoch % config['save_epoch'] == 0:
                    with torch.no_grad():
                        model.eval()
                        avg_PSNR = 0
                        idx = 0
                        for test_data in test_loader:
                            idx += 1
                            test_data['L'] = test_data['L'].cuda()
                            sharp = model(test_data['L'])
                            sharp = sharp.detach().float().cpu()
                            sharp = util.tensor2uint(sharp)
                            test_data['H'] = util.tensor2uint(test_data['H'])
                            current_psnr = util.calculate_psnr(sharp,
                                                               test_data['H'],
                                                               border=0)

                            avg_PSNR += current_psnr
                            if idx % 100 == 0:
                                print("epoch {}: tested {}".format(epoch, idx))
                        avg_PSNR = avg_PSNR / idx
                        print("total PSNR : {:<4.2f}".format(avg_PSNR))
                        viz.line(X=[epoch],
                                 Y=[avg_PSNR],
                                 win='testPSNR-' + itr,
                                 opts=dict(title='psnr',
                                           legend=['valid_psnr']),
                                 update='append')
                        if avg_PSNR > bestPSNR:
                            bestPSNR = avg_PSNR
                            save_path = os.path.join(cfg.output_dir,
                                                     'finish.pth')
                            engine.save_pth(save_path)

                # train
                avg_loss = 0.0
                idx = 0
                model.train()
                for i, train_data in enumerate(train_loader):
                    idx += 1
                    train_data['L'] = train_data['L'].cuda()
                    train_data['H'] = train_data['H'].cuda()
                    optimizer.zero_grad()
                    loss = train_one_step(model, train_data['L'], train_data['H'], criterion,\
                                     optimizer,param_name_to_merge_matrix,\
                                     param_name_to_decay_matrix)

                    avg_loss += loss.item()
                    if idx % 100 == 0:
                        print("epoch {}: trained {}".format(epoch, idx))

                scheduler.step()
                avg_loss = avg_loss / idx
                print("epoch {}: total loss : {:<4.2f}, lr : {}".format(
                    epoch, avg_loss,
                    scheduler.get_lr()[0]))
                viz.line(X=[epoch],
                         Y=[avg_loss],
                         win='trainMSELoss-' + itr,
                         opts=dict(title='mse', legend=['train_mse']),
                         update='append')
            # engine.save_pth(os.path.join(cfg.output_dir, 'finish.pth'))

        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
Example #12
 def link_tb(self, source, target):
     ensure_dir(source)
     ensure_dir(target)
     link_file(source, target)
Example #13
def train_main(local_rank,
               cfg: BaseConfigByEpoch,
               net=None,
               train_dataloader=None,
               val_dataloader=None,
               show_variables=False,
               convbuilder=None,
               init_hdf5=None,
               no_l2_keywords='depth',
               gradient_mask=None,
               use_nesterov=False,
               tensorflow_style_init=False,
               load_weights_keyword=None,
               keyword_to_lr_mult=None,
               auto_continue=False,
               lasso_keyword_to_strength=None,
               save_hdf5_epochs=10000):

    if no_l2_keywords is None:
        no_l2_keywords = []
    if type(no_l2_keywords) is not list:
        no_l2_keywords = [no_l2_keywords]

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if tensorflow_style_init:
            init_as_tensorflow(model)
        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_hdf5(init_hdf5,
                             load_weights_keyword=load_weights_keyword)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    if_accum_grad,
                    gradient_mask_tensor=gradient_mask_tensor,
                    lasso_keyword_to_strength=lasso_keyword_to_strength)
                train_net_time_end = time.time()

                if TRAIN_SPEED_START * max_iters < iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))
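
Example #13 builds gradient_mask_tensor from the gradient_mask argument and forwards it to train_one_step, which suggests that selected gradients are scaled (or zeroed) element-wise before the optimizer step. A hedged sketch of that step, assuming each mask is simply multiplied into the matching gradient:

def apply_gradient_mask(model, gradient_mask_tensor):
    # Multiply each named gradient by its mask, if one was provided.
    if gradient_mask_tensor is None:
        return
    for name, param in model.named_parameters():
        if name in gradient_mask_tensor and param.grad is not None:
            param.grad.mul_(gradient_mask_tensor[name])
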
Example #14
def csgd_train_and_prune(cfg: BaseConfigByEpoch,
                         target_deps,
                         centri_strength,
                         pacesetter_dict,
                         succeeding_strategy,
                         pruned_weights,
                         net=None,
                         train_dataloader=None,
                         val_dataloader=None,
                         show_variables=False,
                         convbuilder=None,
                         beginning_msg=None,
                         init_hdf5=None,
                         no_l2_keywords=None,
                         use_nesterov=False,
                         tensorflow_style_init=False):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')

    with Engine() as engine:

        is_main_process = (engine.world_rank == 0)  #TODO correct?

        logger = engine.setup_log(name='train',
                                  log_dir=cfg.output_dir,
                                  file_name='log.txt')

        # -- typical model components: model, opt, scheduler, dataloader --#
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name,
                                              cfg.dataset_subset,
                                              cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                                            'val',
                                            batch_size=100)  #TODO 100?

        print('NOTE: Data prepared')
        print(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))

        optimizer = get_optimizer(cfg, model, use_nesterov=use_nesterov)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer,
                              cfg=cfg)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(
                engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[engine.world_rank],
                broadcast_buffers=False,
            )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        if tensorflow_style_init:
            for k, v in model.named_parameters():
                if v.dim() in [2, 4]:
                    torch.nn.init.xavier_uniform_(v)
                    print('init {} as xavier_uniform'.format(k))
                if 'bias' in k and 'bn' not in k.lower():
                    torch.nn.init.zeros_(v)
                    print('init {} as zero'.format(k))

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)

        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list()

        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path, allow_pickle=True).item()
        else:
            layer_idx_to_clusters = get_layer_idx_to_clusters(
                kernel_namedvalue_list=kernel_namedvalue_list,
                target_deps=target_deps,
                pacesetter_dict=pacesetter_dict)
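            # layers constrained to share widths (pacesetter_dict) reuse their pacesetter's clusters so they are pruned consistently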
            if pacesetter_dict is not None:
                for follower_idx, pacesetter_idx in pacesetter_dict.items():
                    if pacesetter_idx in layer_idx_to_clusters:
                        layer_idx_to_clusters[
                            follower_idx] = layer_idx_to_clusters[
                                pacesetter_idx]

            np.save(clusters_save_path, layer_idx_to_clusters)

        csgd_save_file = os.path.join(cfg.output_dir, 'finish.hdf5')
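        # 'finish.hdf5' marks a completed C-SGD run; if it already exists, training is skipped and pruning runs directly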

        if os.path.exists(csgd_save_file):
            engine.load_hdf5(csgd_save_file)
        else:
            param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list)
            param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list,
                weight_decay=cfg.weight_decay,
                centri_strength=centri_strength)
            # if pacesetter_dict is not None:
            #     for follower_idx, pacesetter_idx in pacesetter_dict.items():
            #         follower_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         pacesetter_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         if pacesetter_kernel_name in param_name_to_merge_matrix:
            #             param_name_to_merge_matrix[follower_kernel_name] = param_name_to_merge_matrix[
            #                 pacesetter_kernel_name]
            #             param_name_to_decay_matrix[follower_kernel_name] = param_name_to_decay_matrix[
            #                 pacesetter_kernel_name]

            add_vecs_to_mat_dicts(param_name_to_merge_matrix)

            if show_variables:
                engine.show_variables()

            if beginning_msg:
                engine.log(beginning_msg)

            logger.info("\n\nStart training with pytorch version {}".format(
                torch.__version__))

            iteration = engine.state.iteration
            # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
            iters_per_epoch = num_iters_per_epoch(cfg)
            max_iters = iters_per_epoch * cfg.max_epochs
            tb_writer = SummaryWriter(cfg.tb_dir)
            tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

            model.train()

            done_epochs = iteration // iters_per_epoch

            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

            recorded_train_time = 0
            recorded_train_examples = 0

            for epoch in range(done_epochs, cfg.max_epochs):

                pbar = tqdm(range(iters_per_epoch))
                top1 = AvgMeter()
                top5 = AvgMeter()
                losses = AvgMeter()
                discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
                pbar.set_description('Train' + discrip_str)

                if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                    model.eval()
                    val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for val on ImageNet and CIFAR
                    eval_dict, _ = run_eval(val_dataloader,
                                            val_iters,
                                            model,
                                            criterion,
                                            discrip_str,
                                            dataset_name=cfg.dataset_name)
                    val_top1_value = eval_dict['top1'].item()
                    val_top5_value = eval_dict['top5'].item()
                    val_loss_value = eval_dict['loss'].item()
                    for tag, value in zip(
                            tb_tags,
                        [val_top1_value, val_top5_value, val_loss_value]):
                        tb_writer.add_scalars(tag, {'Val': value}, iteration)
                    engine.log(
                        'validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'
                        .format(epoch, val_top1_value, val_top5_value,
                                val_loss_value))
                    model.train()

                for _ in pbar:

                    start_time = time.time()
                    data, label = load_cuda_data(train_dataloader,
                                                 cfg.dataset_name)
                    data_time = time.time() - start_time

                    train_net_time_start = time.time()
                    acc, acc5, loss = train_one_step(
                        model,
                        data,
                        label,
                        optimizer,
                        criterion,
                        param_name_to_merge_matrix=param_name_to_merge_matrix,
                        param_name_to_decay_matrix=param_name_to_decay_matrix)
                    train_net_time_end = time.time()
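                    # only the middle of the run (TRAIN_SPEED_START to TRAIN_SPEED_END of max_iters) counts toward the speed measurement below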

                    if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                        recorded_train_examples += cfg.global_batch_size
                        recorded_train_time += train_net_time_end - train_net_time_start

                    scheduler.step()

                    if iteration % cfg.tb_iter_period == 0 and is_main_process:
                        for tag, value in zip(
                                tb_tags,
                            [acc.item(), acc5.item(),
                             loss.item()]):
                            tb_writer.add_scalars(tag, {'Train': value},
                                                  iteration)

                    top1.update(acc.item())
                    top5.update(acc5.item())
                    losses.update(loss.item())

                    pbar_dic = OrderedDict()
                    pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                    pbar_dic['cur_iter'] = iteration
                    pbar_dic['lr'] = scheduler.get_lr()[0]
                    pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                    pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                    pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                    pbar.set_postfix(pbar_dic)

                    if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                        engine.update_iteration(iteration)
                        if (not engine.distributed) or (engine.distributed
                                                        and is_main_process):
                            engine.save_and_link_checkpoint(cfg.output_dir)

                    iteration += 1
                    if iteration >= max_iters:
                        break

                #   do something after an epoch?
                if iteration >= max_iters:
                    break
            #   do something after the training
            if recorded_train_time > 0:
                exp_per_sec = recorded_train_examples / recorded_train_time
            else:
                exp_per_sec = 0
            engine.log(
                'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
                .format(int(TRAIN_SPEED_START * max_iters),
                        int(TRAIN_SPEED_END * max_iters),
                        cfg.global_batch_size, recorded_train_examples,
                        recorded_train_time, exp_per_sec))
            if cfg.save_weights:
                engine.save_checkpoint(cfg.save_weights)
                print('NOTE: training finished, saved to {}'.format(
                    cfg.save_weights))
            engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))

        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
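
Note on Example #14: train_one_step receives param_name_to_merge_matrix and param_name_to_decay_matrix, but what they do is not visible in this snippet. Presumably it performs a standard forward/backward pass and then transforms the conv-kernel gradients with these matrices before the optimizer step. Below is a minimal, hypothetical sketch of just that gradient transform (apply_csgd_gradients is an illustrative name, not from the repository): the merge matrix averages the gradients of the kernels inside each cluster, and the decay matrix adds a decay term pulling clustered kernels toward their cluster mean.

import torch

def apply_csgd_gradients(model, param_name_to_merge_matrix,
                         param_name_to_decay_matrix):
    # Hypothetical helper: call after loss.backward() and before optimizer.step().
    # Assumes 4-D conv kernels and matrices keyed by the parameter name
    # without any 'module.' prefix added by (Distributed)DataParallel.
    for name, param in model.named_parameters():
        key = name.replace('module.', '')
        if key in param_name_to_merge_matrix and param.dim() == 4:
            out_channels = param.size(0)
            w = param.detach().reshape(out_channels, -1)   # kernel as a 2-D matrix
            g = param.grad.reshape(out_channels, -1)       # its gradient
            # cluster-averaged gradient plus decay pulling kernels toward cluster means
            new_g = param_name_to_merge_matrix[key].matmul(g) \
                    + param_name_to_decay_matrix[key].matmul(w)
            param.grad.copy_(new_g.reshape_as(param))

The decay matrix is built from both cfg.weight_decay and centri_strength (see generate_decay_matrix_for_kernel_and_vecs above), so ordinary weight decay and the centripetal pull are covered by one linear operator.
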
Example #15
def ding_train(cfg: BaseConfigByEpoch,
               net=None,
               train_dataloader=None,
               val_dataloader=None,
               show_variables=False,
               convbuilder=None):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine(cfg) as engine:

        is_main_process = (engine.world_rank == 0)  #TODO correct?

        logger = engine.setup_log(name='train',
                                  log_dir=cfg.output_dir,
                                  file_name='log.txt')

        # -- typical model components: model, optimizer, scheduler, dataloader --#
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder()

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name,
                                              cfg.dataset_subset,
                                              cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                                            'val',
                                            batch_size=100)  #TODO 100?

        print('NOTE: Data prepared')
        print(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))

        # device = torch.device(cfg.device)
        # model.to(device)
        # model.cuda()

        optimizer = get_optimizer(cfg, model)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(
                engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[engine.world_rank],
                broadcast_buffers=False,
            )
            # model = DistributedDataParallel(model, delay_allreduce=True)

        if engine.continue_state_object:
            engine.restore_checkpoint()
        else:
            if cfg.init_weights:
                engine.load_checkpoint(cfg.init_weights, is_restore=False)

        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        logger.info("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch

        for epoch in range(done_epochs, cfg.max_epochs):

            pbar = tqdm(range(iters_per_epoch))
            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                model.eval()
                val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for val on ImageNet and CIFAR
                eval_dict = run_eval(val_dataloader,
                                     val_iters,
                                     model,
                                     criterion,
                                     discrip_str,
                                     dataset_name=cfg.dataset_name)
                val_top1_value = eval_dict['top1'].item()
                val_top5_value = eval_dict['top5'].item()
                val_loss_value = eval_dict['loss'].item()
                for tag, value in zip(
                        tb_tags,
                    [val_top1_value, val_top5_value, val_loss_value]):
                    tb_writer.add_scalars(tag, {'Val': value}, iteration)
                engine.log(
                    'validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'
                    .format(epoch, val_top1_value, val_top5_value,
                            val_loss_value))
                model.train()

            for _ in pbar:

                scheduler.step()

                start_time = time.time()
                data, label = load_cuda_data(train_dataloader,
                                             cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)
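                # True on all but every cfg.grad_accum_iters-th iteration, signalling train_one_step to accumulate gradients instead of stepping the optimizer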
                acc, acc5, loss = train_one_step(model, data, label, optimizer,
                                                 criterion, if_accum_grad)

                if iteration % cfg.tb_iter_period == 0 and is_main_process:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed
                                                    and is_main_process):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                iteration += 1
                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            if iteration >= max_iters:
                break
        #   do something after the training
        engine.save_checkpoint(cfg.save_weights)
        print('NOTE: training finished, saved to {}'.format(cfg.save_weights))
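
Note on Example #15: if_accum_grad implements gradient accumulation, so the effective batch size is cfg.global_batch_size * cfg.grad_accum_iters. A minimal, hypothetical sketch of a matching train_one_step (with top-1/top-5 accuracy computed inline) might look like this:

import torch

def train_one_step(model, data, label, optimizer, criterion, if_accum_grad):
    # Hypothetical sketch: gradients are accumulated across batches and the
    # optimizer only steps when if_accum_grad is False.
    pred = model(data)
    loss = criterion(pred, label)
    loss.backward()                      # adds to any accumulated gradients
    if not if_accum_grad:
        optimizer.step()
        optimizer.zero_grad()
    with torch.no_grad():                # top-1 / top-5 accuracy in percent
        _, top5_idx = pred.topk(5, dim=1)
        correct = top5_idx.eq(label.view(-1, 1))
        acc = correct[:, :1].float().sum() * 100.0 / label.size(0)
        acc5 = correct.float().sum() * 100.0 / label.size(0)
    return acc, acc5, loss
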
Example #16
def aofp_train_main(local_rank,
                    target_layers,
                    succ_strategy,
                    warmup_iterations,
                    aofp_batches_per_half,
                    flops_func,
                    cfg: BaseConfigByEpoch,
                    net=None,
                    train_dataloader=None,
                    val_dataloader=None,
                    show_variables=False,
                    convbuilder=None,
                    init_hdf5=None,
                    no_l2_keywords='depth',
                    gradient_mask=None,
                    use_nesterov=False,
                    tensorflow_style_init=False,
                    keyword_to_lr_mult=None,
                    auto_continue=False,
                    lasso_keyword_to_strength=None,
                    save_hdf5_epochs=10000,
                    remain_flops_ratio=0):

    if no_l2_keywords is None:
        no_l2_keywords = []
    if not isinstance(no_l2_keywords, list):
        no_l2_keywords = [no_l2_keywords]

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        else:
            train_data = train_dataloader  # reuse the dataloader passed by the caller
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        elif val_dataloader is not None:
            val_data = val_dataloader
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if tensorflow_style_init:
            init_as_tensorflow(model)
        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_part('base_path.', init_hdf5)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        #########################   aofp
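        # stagger the per-layer search: layer i starts its AOFP search warmup_iterations + _init_interval * i iterations into training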
        _init_interval = aofp_batches_per_half // len(target_layers)
        layer_to_start_iter = {
            i: (_init_interval * i + warmup_iterations)
            for i in target_layers
        }
        print(
            'the initial layer_to_start_iter = {}'.format(layer_to_start_iter))
        #   0.  get all the AOFPLayers
        layer_idx_to_module = {}
        for submodule in model.modules():
            if hasattr(submodule, 'score_mask') or hasattr(
                    submodule, 't_value'):
                layer_idx_to_module[submodule.conv_idx] = submodule
        print(layer_idx_to_module)
        ######################################

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()

                ############    aofp
                #   1.  see if it is time to start on every layer
                #   2.  forward and accumulate
                #   3.  if a half on some layer is finished, do something
                #   ----    fetch its accumulated t vectors, analyze the first 'granu' elements
                #   ----    if good enough, set the base mask, reset the search space
                #   ----    elif granu == 1, do nothing
                #   ----    else, granu /= 2, reset the search space
                for layer_idx, start_iter in layer_to_start_iter.items():
                    if start_iter == iteration:
                        layer_idx_to_module[layer_idx].start_aofp(iteration)
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    if_accum_grad,
                    gradient_mask_tensor=gradient_mask_tensor,
                    lasso_keyword_to_strength=lasso_keyword_to_strength)
                for layer_idx, aofp_layer in layer_idx_to_module.items():
                    #   accumulate
                    if layer_idx not in succ_strategy:
                        continue
                    follow_layer_idx = succ_strategy[layer_idx]
                    if follow_layer_idx not in layer_idx_to_module:
                        continue
                    t_value = layer_idx_to_module[follow_layer_idx].t_value
                    aofp_layer.accumulate_t_value(t_value)
                    if aofp_layer.finished_a_half(iteration):
                        aofp_layer.halve_or_stop(iteration)
                ###################################

                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            cur_deps = np.array(cfg.deps)
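            # recompute each layer's surviving width from base_mask, then compare the remaining-FLOPs ratio with the target to decide whether to stop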
            for submodule in model.modules():
                if hasattr(submodule, 'base_mask'):
                    cur_deps[submodule.conv_idx] = np.sum(
                        submodule.base_mask.cpu().numpy() == 1)
            origin_flops = flops_func(cfg.deps)
            cur_flops = flops_func(cur_deps)
            remain_ratio = cur_flops / origin_flops
            if local_rank == 0:
                print('##########################')
                print('origin deps ', cfg.deps)
                print('cur deps ', cur_deps)
                print('remain flops ratio = ', remain_ratio, 'the target is ',
                      remain_flops_ratio)
                print('##########################')
            if remain_ratio < remain_flops_ratio:
                break
            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))

        final_deps = aofp_prune(model,
                                origin_deps=cfg.deps,
                                succ_strategy=succ_strategy,
                                save_path=os.path.join(cfg.output_dir,
                                                       'finish_pruned.hdf5'))
        origin_flops = flops_func(cfg.deps)
        cur_flops = flops_func(final_deps)
        engine.log(
            '##################################################################'
        )
        engine.log(cfg.network_type)
        engine.log('origin width: {} , flops {} '.format(
            cfg.deps, origin_flops))
        engine.log('final width: {}, flops {} '.format(final_deps, cur_flops))
        engine.log('flops reduction: {}'.format(1 - cur_flops / origin_flops))
        return final_deps
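
Note on the meters used in Examples #14-#16: top1, top5 and losses are AvgMeter instances that are updated once per batch and read through .mean. A minimal sketch of such a meter, assuming this is its only interface, is:

class AvgMeter(object):
    """Minimal running-average meter: call update(value) per batch, read .mean."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def mean(self):
        return self.sum / self.count if self.count > 0 else 0.0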