Example #1
def main(args):

    train_generator, val_generator = get_generator(args)
    model = get_model(args, train_generator.num_class)
    loss_object = tf.keras.losses.CategoricalCrossentropy()
    optimizer = get_optimizers(args)
    lr_scheduler = get_lr_scheduler(args)
    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)
    # lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=20, verbose=1, min_lr=0)
    lr_cb = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
    model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath=args.checkpoints +
        '/best_weight_{epoch}_{accuracy:.3f}_{val_accuracy:.3f}.h5',
        monitor='val_accuracy',
        mode='max',
        verbose=1,
        save_best_only=True,
        save_weights_only=True)
    cbs = [lr_cb, model_checkpoint_cb]
    model.compile(
        optimizer,
        loss_object,
        metrics=["accuracy"],
    )
    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=cbs,
        verbose=1,
    )
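A minimal sketch of the get_lr_scheduler factory that Example #1 assumes (a plain step decay; args.init_lr and args.lr_decay are assumed parameter names, and the project's real scheduler may well be cosine or warmup-based):

def get_lr_scheduler(args):
    # returns a function usable with tf.keras.callbacks.LearningRateScheduler
    def scheduler(epoch, lr=None):
        # hypothetical step decay: shrink the initial LR by args.lr_decay every 30 epochs
        return args.init_lr * (args.lr_decay ** (epoch // 30))
    return scheduler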
Example #2
def main(args):
    #create dataset
    train_generator, val_dataset, pred_generator = get_generator(args)
    #create model
    model = get_model(args,training=True)
    #create loss
    loss_fun = get_loss(args)
    #create learning rate scheduler
    lr_scheduler = get_lr_scheduler(args)
    #create optimizer
    optimizer = get_optimizers(args)
    best_weight_path = ''
    #tensorboard
    open_tensorboard_url = False
    os.system('rm -rf ./logs/')
    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', 'logs','--reload_interval','15'])
    url = tb.launch()
    print("Tensorboard engine is running at {}".format(url))
    if args.train_mode == 'fit':
        if open_tensorboard_url:
            webbrowser.open(url, new=1)
        mAP_writer = tf.summary.create_file_writer("logs/mAP")
        coco_map_callback = CocoMapCallback(pred_generator,model,args,mAP_writer)
        callbacks = [
            tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
            coco_map_callback,
            # ReduceLROnPlateau(verbose=1),
            # EarlyStopping(patience=3, verbose=1),
            TensorBoard(log_dir='logs')
        ]
        model.compile(optimizer=optimizer,loss=loss_fun,run_eagerly=False)
        model.fit(train_generator,epochs=args.epochs,
                            callbacks=callbacks,
                            # validation_data=val_dataset,
                            max_queue_size=10,
                            workers=8,
                            use_multiprocessing=False
                            )
        best_weight_path = coco_map_callback.best_weight_path
    else:
        print("loading dataset...")
        start_time = time.perf_counter()
        coco_map = EagerCocoMap(pred_generator, model, args)
        max_coco_map = -1
        max_coco_map_epoch = -1
        accumulate_num = args.accumulated_gradient_num
        accumulate_index = 0
        accum_gradient = [tf.Variable(tf.zeros_like(this_var)) for this_var in model.trainable_variables]

        train_writer = tf.summary.create_file_writer("logs/train")
        mAP_writer = tf.summary.create_file_writer("logs/mAP")

        for epoch in range(int(args.epochs)):
            lr = lr_scheduler(epoch)
            optimizer.learning_rate.assign(lr)
            remaining_epochs = args.epochs - epoch - 1
            epoch_start_time = time.perf_counter()
            train_loss = 0
            train_generator_tqdm = tqdm(enumerate(train_generator), total=len(train_generator))
            for batch_index, (batch_imgs, batch_labels)  in train_generator_tqdm:
                s1 = time.time()
                if args.model_name == "efficientdet":
                    with tf.GradientTape() as tape:
                        model_outputs = model(batch_imgs, training=True)
                        num_level = args.max_level - args.min_level + 1
                        cls_loss,box_loss = 0,0
                        for level in range(num_level):
                            cls_loss += loss_fun[0][level](batch_labels[0][level],model_outputs[0][level])
                            box_loss += loss_fun[1][level](batch_labels[1][level], model_outputs[1][level])
                        data_loss = cls_loss+box_loss
                        # data_loss = loss_fun(batch_labels,model_outputs)

                        total_loss = data_loss + args.weight_decay * tf.add_n(
                            [tf.nn.l2_loss(v) for v in model.trainable_variables if
                             'batch_normalization' not in v.name])
                else:
                    raise ValueError('unsupported model type {}'.format(args.model_name))

                grads = tape.gradient(total_loss, model.trainable_variables)
                accum_gradient = [acum_grad.assign_add(grad) for acum_grad, grad in zip(accum_gradient, grads)]

                accumulate_index += 1
                if accumulate_index == accumulate_num:
                    optimizer.apply_gradients(zip(accum_gradient, model.trainable_variables))
                    accum_gradient = [ grad.assign_sub(grad) for grad in accum_gradient]
                    accumulate_index = 0
                train_loss += total_loss
                train_generator_tqdm.set_description(
                    "epoch:{}/{},train_loss:{:.4f},lr:{:.6f}".format(epoch, args.epochs,
                                                                                     train_loss/(batch_index+1),
                                                                                     optimizer.learning_rate.numpy()))
            train_generator.on_epoch_end()

            with train_writer.as_default():
                tf.summary.scalar("train_loss", train_loss/len(train_generator), step=epoch)
                train_writer.flush()

            #evaluation
            if epoch >= args.start_eval_epoch:
                if epoch % args.eval_epoch_interval == 0:
                    summary_metrics = coco_map.eval()
                    if summary_metrics['Precision/mAP@.50IOU'] > max_coco_map:
                        max_coco_map = summary_metrics['Precision/mAP@.50IOU']
                        max_coco_map_epoch = epoch
                        best_weight_path = os.path.join(args.checkpoints_dir, 'best_weight_{}_{}_{:.3f}'.format(args.model_name+"_"+args.model_type,max_coco_map_epoch, max_coco_map))
                        model.save_weights(best_weight_path)

                    print("max_coco_map:{},epoch:{}".format(max_coco_map,max_coco_map_epoch))
                    with mAP_writer.as_default():
                        tf.summary.scalar("[email protected]", summary_metrics['Precision/[email protected]'], step=epoch)
                        mAP_writer.flush()

            cur_time = time.perf_counter()
            one_epoch_time = cur_time - epoch_start_time
            print("time elapsed: {:.3f} hour, time left: {:.3f} hour".format((cur_time-start_time)/3600,remaining_epoches*one_epoch_time/3600))

            if epoch>0 and not open_tensorboard_url:
                open_tensorboard_url = True
                webbrowser.open(url,new=1)

    print("Training is finished!")
    #save model
    print("Exporting model...")
    if args.export_dir and best_weight_path:
        tf.keras.backend.clear_session()
        pred_model = get_model(args, training=False)
        pred_model.load_weights(best_weight_path)
        best_model_path = os.path.join(args.export_dir,best_weight_path.split('/')[-1].replace('weight','model'),'1')
        tf.saved_model.save(pred_model, best_model_path)
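The export step above writes a TensorFlow SavedModel into a versioned directory. A usage sketch for reloading it at inference time (the dummy input shape and the presence of a 'serving_default' signature are assumptions):

import tensorflow as tf

loaded = tf.saved_model.load(best_model_path)
infer = loaded.signatures['serving_default']
# outputs = infer(tf.zeros([1, 416, 416, 3], tf.float32))  # dummy input; the real shape depends on args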
Example #3
trainloader, testloader = loader(LOADERNAME)
beg = time.time()
net = get_network(MODELNAME, NUMCLASS)
if LOADING:
    net.load_state_dict(torch.load(LOADPATH, map_location='cpu'))
net.cuda()
if torch.cuda.device_count() > 1:
    net = DataParallel(net)
end = time.time()
print('[*] model load time is {:.3f}s'.format(end - beg))
iters = len(trainloader)
optimizer = torch.optim.SGD(net.parameters(),
                            lr=INITLR,
                            momentum=0.9,
                            weight_decay=WD)
scheduler = get_lr_scheduler(optimizer, LRTAG)
warmup = WarmUpLR(optimizer, iters * WARM)

print('[*] train start !!!!!!!!!!!')
best_acc = 0      # best accuracy seen so far across all epochs
best_epoch = 0
for epoch in range(EPOCHS):
    net.train()
    train_loss = 0
    total = 0
    for i, data in enumerate(trainloader):
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        optimizer.zero_grad()
        if MIXUP:
            img, labela, labelb, lam = mixup_data(img, label)
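Example #3 breaks off right after the mixup branch. For context, a minimal sketch of a standard mixup_data helper and the matching loss term (following Zhang et al.'s mixup; the project's actual helper may differ):

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # mix each batch with a shuffled copy of itself and keep both label sets
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # the loss is the same convex combination applied to the two label sets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)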
Example #4
def csgd_train_main(local_rank,
                    cfg: BaseConfigByEpoch,
                    target_deps,
                    succeeding_strategy,
                    pacesetter_dict,
                    centri_strength,
                    pruned_weights,
                    net=None,
                    train_dataloader=None,
                    val_dataloader=None,
                    show_variables=False,
                    convbuilder=None,
                    init_hdf5=None,
                    no_l2_keywords='depth',
                    use_nesterov=False,
                    load_weights_keyword=None,
                    keyword_to_lr_mult=None,
                    auto_continue=False,
                    save_hdf5_epochs=10000):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')

    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        else:
            train_data = train_dataloader
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        elif cfg.val_epoch_period > 0:
            val_data = val_dataloader
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        if no_l2_keywords is None:
            no_l2_keywords = []
        if type(no_l2_keywords) is not list:
            no_l2_keywords = [no_l2_keywords]
        # For a target parameter, cancel its weight decay in optimizer, because the weight decay will be later encoded in the decay mat
        conv_idx = 0
        for k, v in model.named_parameters():
            if v.dim() != 4:
                continue
            print('prune layer {} from {} to {} filters'.format(
                conv_idx, cfg.deps[conv_idx], target_deps[conv_idx]))
            if target_deps[conv_idx] < cfg.deps[conv_idx]:
                no_l2_keywords.append(k.replace(KERNEL_KEYWORD, 'conv'))
                no_l2_keywords.append(k.replace(KERNEL_KEYWORD, 'bn'))
            conv_idx += 1
        print('no l2: ', no_l2_keywords)
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_hdf5(init_hdf5,
                             load_weights_keyword=load_weights_keyword)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        #   ===================================== prepare the clusters and matrices for C-SGD ==========
        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list()

        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path,
                                            allow_pickle=True).item()
        else:
            if local_rank == 0:
                layer_idx_to_clusters = get_layer_idx_to_clusters(
                    kernel_namedvalue_list=kernel_namedvalue_list,
                    target_deps=target_deps,
                    pacesetter_dict=pacesetter_dict)
                if pacesetter_dict is not None:
                    for follower_idx, pacesetter_idx in pacesetter_dict.items(
                    ):
                        if pacesetter_idx in layer_idx_to_clusters:
                            layer_idx_to_clusters[
                                follower_idx] = layer_idx_to_clusters[
                                    pacesetter_idx]

                np.save(clusters_save_path, layer_idx_to_clusters)
            else:
                while not os.path.exists(clusters_save_path):
                    time.sleep(10)
                    print('sleep, waiting for process 0 to calculate clusters')
                layer_idx_to_clusters = np.load(clusters_save_path,
                                                allow_pickle=True).item()

        param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
            deps=cfg.deps,
            layer_idx_to_clusters=layer_idx_to_clusters,
            kernel_namedvalue_list=kernel_namedvalue_list)
        add_vecs_to_merge_mat_dicts(param_name_to_merge_matrix)
        param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
            deps=cfg.deps,
            layer_idx_to_clusters=layer_idx_to_clusters,
            kernel_namedvalue_list=kernel_namedvalue_list,
            weight_decay=cfg.weight_decay,
            weight_decay_bias=cfg.weight_decay_bias,
            centri_strength=centri_strength)
        print(param_name_to_decay_matrix.keys())
        print(param_name_to_merge_matrix.keys())

        conv_idx = 0
        param_to_clusters = {}
        for k, v in model.named_parameters():
            if v.dim() != 4:
                continue
            if conv_idx in layer_idx_to_clusters:
                for clsts in layer_idx_to_clusters[conv_idx]:
                    if len(clsts) > 1:
                        param_to_clusters[v] = layer_idx_to_clusters[conv_idx]
                        break
            conv_idx += 1
        #   ============================================================================================

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    param_name_to_merge_matrix=param_name_to_merge_matrix,
                    param_name_to_decay_matrix=param_name_to_decay_matrix)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)
                    deviation_sum = 0
                    for param, clusters in param_to_clusters.items():
                        pvalue = param.detach().cpu().numpy()
                        for cl in clusters:
                            if len(cl) == 1:
                                continue
                            selected = pvalue[cl, :, :, :]
                            mean_kernel = np.mean(selected,
                                                  axis=0,
                                                  keepdims=True)
                            diff = selected - mean_kernel
                            deviation_sum += np.sum(diff**2)
                    tb_writer.add_scalars('deviation_sum',
                                          {'Train': deviation_sum}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))

    if local_rank == 0:
        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
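train_one_step is defined elsewhere in the project; the role of the merge and decay matrices it receives is roughly the C-SGD gradient transform sketched below (following the Centripetal SGD paper, with reshaping details simplified, so names and shapes are assumptions):

import torch

def csgd_transform_grads(model, param_name_to_merge_matrix, param_name_to_decay_matrix):
    # replace each clustered kernel's gradient by merge_mat @ grad + decay_mat @ weight,
    # which averages gradients within a cluster and pulls kernels toward the cluster center
    for name, param in model.named_parameters():
        if name not in param_name_to_merge_matrix:
            continue
        merge_mat = param_name_to_merge_matrix[name]   # [C_out, C_out]
        decay_mat = param_name_to_decay_matrix[name]   # [C_out, C_out]
        g = param.grad.reshape(param.size(0), -1)
        w = param.detach().reshape(param.size(0), -1)
        param.grad.copy_((merge_mat.matmul(g) + decay_mat.matmul(w)).reshape(param.shape))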
Example #5
def ding_train(cfg: BaseConfigByEpoch,
               net=None,
               train_dataloader=None,
               val_dataloader=None,
               show_variables=False,
               convbuilder=None,
               beginning_msg=None,
               init_hdf5=None,
               no_l2_keywords=None,
               gradient_mask=None,
               use_nesterov=False,
               tensorflow_style_init=False):

    # LOCAL_RANK = 0
    #
    # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # is_distributed = num_gpus > 1
    #
    # if is_distributed:
    #     torch.cuda.set_device(LOCAL_RANK)
    #     torch.distributed.init_process_group(
    #         backend="nccl", init_method="env://"
    #     )
    #     synchronize()
    #
    # torch.backends.cudnn.benchmark = True

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine() as engine:

        is_main_process = (engine.world_rank == 0)  #TODO correct?

        logger = engine.setup_log(name='train',
                                  log_dir=cfg.output_dir,
                                  file_name='log.txt')

        # -- typical model components: model, optimizer, scheduler, dataloader --#
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name,
                                              cfg.dataset_subset,
                                              cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                                            'val',
                                            batch_size=100)  #TODO 100?

        print('NOTE: Data prepared')
        print(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))

        if no_l2_keywords is None:
            no_l2_keywords = []

        optimizer = get_optimizer(cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer,
                              cfg=cfg)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(
                engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[engine.world_rank],
                broadcast_buffers=False,
            )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        if tensorflow_style_init:
            for k, v in model.named_parameters():
                if v.dim() in [2, 4]:
                    torch.nn.init.xavier_uniform_(v)
                    print('init {} as xavier_uniform'.format(k))
                if 'bias' in k and 'bn' not in k.lower():
                    torch.nn.init.zeros_(v)
                    print('init {} as zero'.format(k))

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)

        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        if beginning_msg:
            engine.log(beginning_msg)
        logger.info("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch

        engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            pbar = tqdm(range(iters_per_epoch))
            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                model.eval()
                val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for val on ImageNet and CIFAR
                eval_dict, _ = run_eval(val_dataloader,
                                        val_iters,
                                        model,
                                        criterion,
                                        discrip_str,
                                        dataset_name=cfg.dataset_name)
                val_top1_value = eval_dict['top1'].item()
                val_top5_value = eval_dict['top5'].item()
                val_loss_value = eval_dict['loss'].item()
                for tag, value in zip(
                        tb_tags,
                    [val_top1_value, val_top5_value, val_loss_value]):
                    tb_writer.add_scalars(tag, {'Val': value}, iteration)
                engine.log(
                    'validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'
                    .format(epoch, val_top1_value, val_top5_value,
                            val_loss_value))
                model.train()

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_dataloader,
                                             cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    if_accum_grad,
                    gradient_mask_tensor=gradient_mask_tensor)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                if iteration % cfg.tb_iter_period == 0 and is_main_process:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed
                                                    and is_main_process):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                iteration += 1
                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            if iteration >= max_iters:
                break
        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))
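The progress-bar bookkeeping in these loops only needs a running average; a minimal AvgMeter consistent with the .update()/.mean usage above (the project's own class may track more state):

class AvgMeter(object):
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # accumulate a value observed for n samples
        self.sum += value * n
        self.count += n

    @property
    def mean(self):
        return self.sum / self.count if self.count else 0.0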
Example #6
def train_main(local_rank,
               cfg: BaseConfigByEpoch,
               net=None,
               train_dataloader=None,
               val_dataloader=None,
               show_variables=False,
               convbuilder=None,
               init_hdf5=None,
               no_l2_keywords='depth',
               gradient_mask=None,
               use_nesterov=False,
               tensorflow_style_init=False,
               load_weights_keyword=None,
               keyword_to_lr_mult=None,
               auto_continue=False,
               lasso_keyword_to_strength=None,
               save_hdf5_epochs=10000):

    if no_l2_keywords is None:
        no_l2_keywords = []
    if type(no_l2_keywords) is not list:
        no_l2_keywords = [no_l2_keywords]

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        else:
            train_data = train_dataloader
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        elif cfg.val_epoch_period > 0:
            val_data = val_dataloader
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if tensorflow_style_init:
            init_as_tensorflow(model)
        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_hdf5(init_hdf5,
                             load_weights_keyword=load_weights_keyword)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    if_accum_grad,
                    gradient_mask_tensor=gradient_mask_tensor,
                    lasso_keyword_to_strength=lasso_keyword_to_strength)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))
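gradient_mask_tensor is handed to train_one_step, which is defined elsewhere; the masking idea it implements is typically just an element-wise multiply of each parameter's gradient before the optimizer step, e.g. (a sketch, not the project's actual code):

def apply_gradient_mask(model, gradient_mask_tensor):
    # zero out (or scale) selected gradient entries before optimizer.step();
    # a no-op when no mask is given
    if gradient_mask_tensor is None:
        return
    for name, param in model.named_parameters():
        if name in gradient_mask_tensor and param.grad is not None:
            param.grad.mul_(gradient_mask_tensor[name])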
Example #7
def csgd_train_and_prune(cfg: BaseConfigByEpoch,
                         target_deps,
                         centri_strength,
                         pacesetter_dict,
                         succeeding_strategy,
                         pruned_weights,
                         net=None,
                         train_dataloader=None,
                         val_dataloader=None,
                         show_variables=False,
                         convbuilder=None,
                         beginning_msg=None,
                         init_hdf5=None,
                         no_l2_keywords=None,
                         use_nesterov=False,
                         tensorflow_style_init=False):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')

    with Engine() as engine:

        is_main_process = (engine.world_rank == 0)  #TODO correct?

        logger = engine.setup_log(name='train',
                                  log_dir=cfg.output_dir,
                                  file_name='log.txt')

        # -- typical model components: model, optimizer, scheduler, dataloader --#
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name,
                                              cfg.dataset_subset,
                                              cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                                            'val',
                                            batch_size=100)  #TODO 100?

        print('NOTE: Data prepared')
        print(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))

        optimizer = get_optimizer(cfg, model, use_nesterov=use_nesterov)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer,
                              cfg=cfg)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(
                engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[engine.world_rank],
                broadcast_buffers=False,
            )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        if tensorflow_style_init:
            for k, v in model.named_parameters():
                if v.dim() in [2, 4]:
                    torch.nn.init.xavier_uniform_(v)
                    print('init {} as xavier_uniform'.format(k))
                if 'bias' in k and 'bn' not in k.lower():
                    torch.nn.init.zeros_(v)
                    print('init {} as zero'.format(k))

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)

        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list()

        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path,
                                            allow_pickle=True).item()
        else:
            layer_idx_to_clusters = get_layer_idx_to_clusters(
                kernel_namedvalue_list=kernel_namedvalue_list,
                target_deps=target_deps,
                pacesetter_dict=pacesetter_dict)
            if pacesetter_dict is not None:
                for follower_idx, pacesetter_idx in pacesetter_dict.items():
                    if pacesetter_idx in layer_idx_to_clusters:
                        layer_idx_to_clusters[
                            follower_idx] = layer_idx_to_clusters[
                                pacesetter_idx]

            np.save(clusters_save_path, layer_idx_to_clusters)

        csgd_save_file = os.path.join(cfg.output_dir, 'finish.hdf5')

        if os.path.exists(csgd_save_file):
            engine.load_hdf5(csgd_save_file)
        else:
            param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list)
            param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list,
                weight_decay=cfg.weight_decay,
                centri_strength=centri_strength)
            # if pacesetter_dict is not None:
            #     for follower_idx, pacesetter_idx in pacesetter_dict.items():
            #         follower_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         pacesetter_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         if pacesetter_kernel_name in param_name_to_merge_matrix:
            #             param_name_to_merge_matrix[follower_kernel_name] = param_name_to_merge_matrix[
            #                 pacesetter_kernel_name]
            #             param_name_to_decay_matrix[follower_kernel_name] = param_name_to_decay_matrix[
            #                 pacesetter_kernel_name]

            add_vecs_to_mat_dicts(param_name_to_merge_matrix)

            if show_variables:
                engine.show_variables()

            if beginning_msg:
                engine.log(beginning_msg)

            logger.info("\n\nStart training with pytorch version {}".format(
                torch.__version__))

            iteration = engine.state.iteration
            # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
            iters_per_epoch = num_iters_per_epoch(cfg)
            max_iters = iters_per_epoch * cfg.max_epochs
            tb_writer = SummaryWriter(cfg.tb_dir)
            tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

            model.train()

            done_epochs = iteration // iters_per_epoch

            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

            recorded_train_time = 0
            recorded_train_examples = 0

            for epoch in range(done_epochs, cfg.max_epochs):

                pbar = tqdm(range(iters_per_epoch))
                top1 = AvgMeter()
                top5 = AvgMeter()
                losses = AvgMeter()
                discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
                pbar.set_description('Train' + discrip_str)

                if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                    model.eval()
                    val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for val on ImageNet and CIFAR
                    eval_dict, _ = run_eval(val_dataloader,
                                            val_iters,
                                            model,
                                            criterion,
                                            discrip_str,
                                            dataset_name=cfg.dataset_name)
                    val_top1_value = eval_dict['top1'].item()
                    val_top5_value = eval_dict['top5'].item()
                    val_loss_value = eval_dict['loss'].item()
                    for tag, value in zip(
                            tb_tags,
                        [val_top1_value, val_top5_value, val_loss_value]):
                        tb_writer.add_scalars(tag, {'Val': value}, iteration)
                    engine.log(
                        'validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'
                        .format(epoch, val_top1_value, val_top5_value,
                                val_loss_value))
                    model.train()

                for _ in pbar:

                    start_time = time.time()
                    data, label = load_cuda_data(train_dataloader,
                                                 cfg.dataset_name)
                    data_time = time.time() - start_time

                    train_net_time_start = time.time()
                    acc, acc5, loss = train_one_step(
                        model,
                        data,
                        label,
                        optimizer,
                        criterion,
                        param_name_to_merge_matrix=param_name_to_merge_matrix,
                        param_name_to_decay_matrix=param_name_to_decay_matrix)
                    train_net_time_end = time.time()

                    if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                        recorded_train_examples += cfg.global_batch_size
                        recorded_train_time += train_net_time_end - train_net_time_start

                    scheduler.step()

                    if iteration % cfg.tb_iter_period == 0 and is_main_process:
                        for tag, value in zip(
                                tb_tags,
                            [acc.item(), acc5.item(),
                             loss.item()]):
                            tb_writer.add_scalars(tag, {'Train': value},
                                                  iteration)

                    top1.update(acc.item())
                    top5.update(acc5.item())
                    losses.update(loss.item())

                    pbar_dic = OrderedDict()
                    pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                    pbar_dic['cur_iter'] = iteration
                    pbar_dic['lr'] = scheduler.get_lr()[0]
                    pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                    pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                    pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                    pbar.set_postfix(pbar_dic)

                    if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                        engine.update_iteration(iteration)
                        if (not engine.distributed) or (engine.distributed
                                                        and is_main_process):
                            engine.save_and_link_checkpoint(cfg.output_dir)

                    iteration += 1
                    if iteration >= max_iters:
                        break

                #   do something after an epoch?
                if iteration >= max_iters:
                    break
            #   do something after the training
            if recorded_train_time > 0:
                exp_per_sec = recorded_train_examples / recorded_train_time
            else:
                exp_per_sec = 0
            engine.log(
                'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
                .format(int(TRAIN_SPEED_START * max_iters),
                        int(TRAIN_SPEED_END * max_iters),
                        cfg.global_batch_size, recorded_train_examples,
                        recorded_train_time, exp_per_sec))
            if cfg.save_weights:
                engine.save_checkpoint(cfg.save_weights)
                print('NOTE: training finished, saved to {}'.format(
                    cfg.save_weights))
            engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))

        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
Example #8
def main(args):
    train_generator, _, pred_generator = get_generator(args)

    if args.model_type == "tiny":
        model = Yolov4_tiny(args, training=True)
        if args.use_pretrain:
            if len(os.listdir(os.path.dirname(args.tiny_coco_pretrained_weights))) != 0:
                cur_num_classes = int(args.num_classes)
                args.num_classes = 80
                pretrain_model = Yolov4_tiny(args, training=True)
                pretrain_model.load_weights(args.tiny_coco_pretrained_weights).expect_partial()
                for layer in model.layers:
                    if not layer.get_weights():
                        continue
                    if 'yolov3_head' in layer.name:
                        continue
                    layer.set_weights(pretrain_model.get_layer(layer.name).get_weights())
                args.num_classes = cur_num_classes
                print("Load {} weight successfully!".format(args.model_type))
            else:
                raise ValueError("pretrained_weights directory is empty!")
    elif args.model_type == "p5":
        model = Yolov4(args, training=True)
        if args.use_pretrain:
            if len(os.listdir(os.path.dirname(args.p5_coco_pretrained_weights)))!=0:
                cur_num_classes = int(args.num_classes)
                args.num_classes = 80
                pretrain_model = Yolov4(args, training=True)
                pretrain_model.load_weights(args.p5_coco_pretrained_weights).expect_partial()
                for layer in model.layers:
                    if not layer.get_weights():
                        continue
                    if 'yolov3_head' in layer.name:
                        continue
                    layer.set_weights(pretrain_model.get_layer(layer.name).get_weights())
                args.num_classes = cur_num_classes
                print("Load {} weight successfully!".format(args.model_type))
            else:
                raise ValueError("pretrained_weights directory is empty!")
    elif args.model_type == "p6":
        model = Yolov4(args, training=True)
        if args.use_pretrain:
            if len(os.listdir(os.path.dirname(args.p6_coco_pretrained_weights))) != 0:
                cur_num_classes = int(args.num_classes)
                args.num_classes = 80
                pretrain_model = Yolov4(args, training=True)
                pretrain_model.load_weights(args.p6_coco_pretrained_weights).expect_partial()
                for layer in model.layers:
                    if not layer.get_weights():
                        continue
                    if 'yolov3_head' in layer.name:
                        continue
                    layer.set_weights(pretrain_model.get_layer(layer.name).get_weights())
                args.num_classes = cur_num_classes
                print("Load {} weight successfully!".format(args.model_type))
            else:
                raise ValueError("pretrained_weights directory is empty!")
    else:
        model = Yolov4(args, training=True)
        print("pretrain weight currently don't support p7!")
    num_model_outputs = {"tiny":2, "p5":3,"p6":4,"p7":5}
    loss_fun = [yolov3_loss(args, grid_index) for grid_index in range(num_model_outputs[args.model_type])]
    lr_scheduler = get_lr_scheduler(args)
    optimizer = yolov3_optimizers(args)

    #tensorboard
    open_tensorboard_url = False
    os.system('rm -rf ./logs/')
    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', 'logs','--reload_interval','15'])
    url = tb.launch()
    print("Tensorboard engine is running at {}".format(url))
    best_weight_path = ''

    if args.train_mode == 'fit':
        mAP_writer = tf.summary.create_file_writer("logs/mAP")
        coco_map_callback = CocoMapCallback(pred_generator,model,args,mAP_writer)
        callbacks = [
            tf.keras.callbacks.LearningRateScheduler(lr_scheduler),
            coco_map_callback,
            # ReduceLROnPlateau(verbose=1),
            # EarlyStopping(patience=3, verbose=1),
            # ModelCheckpoint('checkpoints/yolov3_train_{epoch}.tf',verbose=1, save_weights_only=True),
            TensorBoard(log_dir='logs')
        ]
        model.compile(optimizer=optimizer,loss=loss_fun)
        model.fit(train_generator,epochs=args.epochs,
                            callbacks=callbacks,
                            # validation_data=val_dataset,
                            verbose=1,
                            max_queue_size=10,
                            workers=8,
                            use_multiprocessing=False
                            )
        best_weight_path = coco_map_callback.best_weight_path
    else:
        print("loading dataset...")
        start_time = time.perf_counter()
        coco_map = EagerCocoMap(pred_generator, model, args)
        max_coco_map = -1
        max_coco_map_epoch = -1
        accumulate_num = args.accumulated_gradient_num
        accumulate_index = 0
        accum_gradient = [tf.Variable(tf.zeros_like(this_var)) for this_var in model.trainable_variables]

        train_writer = tf.summary.create_file_writer("logs/train")
        mAP_writer = tf.summary.create_file_writer("logs/mAP")
        #training
        for epoch in range(int(args.epochs)):
            lr = lr_scheduler(epoch)
            optimizer.learning_rate.assign(lr)
            remaining_epoches = args.epochs - epoch - 1
            epoch_start_time = time.perf_counter()
            train_loss = 0
            train_generator_tqdm = tqdm(enumerate(train_generator), total=len(train_generator))
            for batch_index, (batch_imgs, batch_labels)  in train_generator_tqdm:
                with tf.GradientTape() as tape:
                    model_outputs = model(batch_imgs, training=True)
                    data_loss = 0
                    for output_index,output_val in enumerate(model_outputs):
                        loss = loss_fun[output_index](batch_labels[output_index], output_val)
                        data_loss += tf.reduce_sum(loss)

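                    # Add manual L2 weight decay over all trainable variables except
                    # batch normalization parameters, scaled by args.weight_decay.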
                    total_loss = data_loss + args.weight_decay*tf.add_n([tf.nn.l2_loss(v) for v in model.trainable_variables if 'batch_normalization' not in v.name])
                grads = tape.gradient(total_loss, model.trainable_variables)
                accum_gradient = [acum_grad.assign_add(grad) for acum_grad, grad in zip(accum_gradient, grads)]

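                # Gradient accumulation: gradients are summed into accum_gradient and the
                # optimizer step is applied (and the buffers reset) only once
                # accumulated_gradient_num batches have been processed.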
                accumulate_index += 1
                if accumulate_index == accumulate_num:
                    optimizer.apply_gradients(zip(accum_gradient, model.trainable_variables))
                    accum_gradient = [ grad.assign_sub(grad) for grad in accum_gradient]
                    accumulate_index = 0
                train_loss += total_loss
                train_generator_tqdm.set_description(
                    "epoch:{}/{},train_loss:{:.4f},lr:{:.6f}".format(epoch, args.epochs,
                                                                                     train_loss/(batch_index+1),
                                                                                     optimizer.learning_rate.numpy()))
            train_generator.on_epoch_end()

            with train_writer.as_default():
                tf.summary.scalar("train_loss", train_loss/len(train_generator), step=epoch)
                train_writer.flush()

            #evaluation
            if epoch >= args.start_eval_epoch:
                if epoch % args.eval_epoch_interval == 0:
                    summary_metrics = coco_map.eval()
                    if summary_metrics['Precision/mAP@.50IOU'] > max_coco_map:
                        max_coco_map = summary_metrics['Precision/mAP@.50IOU']
                        max_coco_map_epoch = epoch
                        best_weight_path = os.path.join(args.checkpoints_dir, 'best_weight_{}_{}_{:.3f}'.format(args.model_type,max_coco_map_epoch, max_coco_map))
                        model.save_weights(best_weight_path)

                    print("max_coco_map:{},epoch:{}".format(max_coco_map,max_coco_map_epoch))
                    with mAP_writer.as_default():
                        tf.summary.scalar("mAP@0.5", summary_metrics['Precision/mAP@.50IOU'], step=epoch)
                        mAP_writer.flush()
            cur_time = time.perf_counter()
            one_epoch_time = cur_time - epoch_start_time
            print("time elapsed: {:.3f} hour, time left: {:.3f} hour".format((cur_time-start_time)/3600,remaining_epoches*one_epoch_time/3600))

            if epoch>0 and not open_tensorboard_url:
                open_tensorboard_url = True
                webbrowser.open(url,new=1)

    print("Training is finished!")
    #save model
    print("Exporting model...")
    if args.output_model_dir and best_weight_path:
        tf.keras.backend.clear_session()
        if args.model_type == "tiny":
            model = Yolov4_tiny(args, training=False)
        else:
            model = Yolov4(args, training=False)
        model.load_weights(best_weight_path)
        best_model_path = os.path.join(args.output_model_dir,best_weight_path.split('/')[-1].replace('weight','model'),'1')
        model.save(best_model_path)
Example #9
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    if args.local_rank == 0:
        logger.info(model)
        tot_params = sum(p.numel() for p in model.parameters()) / 1000000.0
        logger.info(f">>> total params: {tot_params:.2f}M")

        # provide the summary of model
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[0], config.TRAIN.IMAGE_SIZE[1]))
        logger.info(get_model_summary(model.to(device), dump_input.to(device)))

        # copy model file
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # prepare data
    train_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=tuple(config.TRAIN.IMAGE_SIZE),  # (height, width)
        scale_factor=config.TRAIN.SCALE_FACTOR)

    if distributed:
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = None

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE and train_sampler is None,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler)

    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=tuple(config.TEST.IMAGE_SIZE),  # (height, width)
    )

    if distributed:
        test_sampler = DistributedSampler(test_dataset)
    else:
        test_sampler = None

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     weight=train_dataset.class_weights,
                                     thresh=config.LOSS.OHEMTHRESH,
                                     min_kept=config.LOSS.OHEMKEEP)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    model_state_file = config.MODEL.PRETRAINED
    logger.info('=> Loading model from {}'.format(model_state_file))
    pretrained_dict = torch.load(model_state_file)
    model_dict = model.state_dict()
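    # Keep only pretrained entries whose names (with the first 6 characters of the key,
    # presumably a 'model.' prefix, stripped) exist in the current model's state_dict.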
    pretrained_dict = {
        k[6:]: v
        for k, v in pretrained_dict.items() if k[6:] in model_dict.keys()
    }
    for k, _ in pretrained_dict.items():
        logger.info('=> Loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model = FullModel(model, criterion)
    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    # optimizer
    optimizer = get_optimizer(config, model)

    epoch_iters = int(len(train_dataset) /
                      config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=lambda storage, loc: storage)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters

    # learning rate scheduler
    lr_scheduler_dict = {
        'optimizer': optimizer,
        'milestones': [s * epoch_iters for s in config.TRAIN.LR_STEP],
        'gamma': config.TRAIN.LR_FACTOR,
        'max_iters': num_iters,
        'last_epoch': last_epoch,
        'epoch_iters': epoch_iters
    }
    lr_scheduler = get_lr_scheduler(config.TRAIN.LR_SCHEDULER,
                                    **lr_scheduler_dict)

    for epoch in range(last_epoch, end_epoch):
        if distributed:
            train_sampler.set_epoch(epoch)
        train(config, epoch, end_epoch, epoch_iters, trainloader, optimizer,
              lr_scheduler, model, writer_dict, device)

        valid_loss, mean_IoU = validate(config, testloader, model, writer_dict,
                                        device)

        if args.local_rank == 0:
            logger.info(
                '=> saving checkpoint to {}'.format(final_output_dir +
                                                    '/checkpoint.pth.tar'))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(final_output_dir, 'checkpoint.pth.tar'))

            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))
            msg = (f'Loss: {valid_loss:.4f}, MeanIU: {mean_IoU: 4.4f}, '
                   f'Best_mIoU: {best_mIoU: 4.4f}')

            logger.info(msg)

            if epoch == end_epoch - 1:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'final_state.pth'))

                writer_dict['writer'].close()
                end = timeit.default_timer()
                logger.info(f'Hours: {int((end - start) / 3600)}')
                logger.info('Done!')
Example #10
def aofp_train_main(local_rank,
                    target_layers,
                    succ_strategy,
                    warmup_iterations,
                    aofp_batches_per_half,
                    flops_func,
                    cfg: BaseConfigByEpoch,
                    net=None,
                    train_dataloader=None,
                    val_dataloader=None,
                    show_variables=False,
                    convbuilder=None,
                    init_hdf5=None,
                    no_l2_keywords='depth',
                    gradient_mask=None,
                    use_nesterov=False,
                    tensorflow_style_init=False,
                    keyword_to_lr_mult=None,
                    auto_continue=False,
                    lasso_keyword_to_strength=None,
                    save_hdf5_epochs=10000,
                    remain_flops_ratio=0):

    if no_l2_keywords is None:
        no_l2_keywords = []
    if type(no_l2_keywords) is not list:
        no_l2_keywords = [no_l2_keywords]

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if tensorflow_style_init:
            init_as_tensorflow(model)
        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_part('base_path.', init_hdf5)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        #########################   aofp
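        # Stagger the start of the AOFP search across layers: layer i begins at
        # warmup_iterations + i * (aofp_batches_per_half // len(target_layers)).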
        _init_interval = aofp_batches_per_half // len(target_layers)
        layer_to_start_iter = {
            i: (_init_interval * i + warmup_iterations)
            for i in target_layers
        }
        print(
            'the initial layer_to_start_iter = {}'.format(layer_to_start_iter))
        #   0.  get all the AOFPLayers
        layer_idx_to_module = {}
        for submodule in model.modules():
            if hasattr(submodule, 'score_mask') or hasattr(
                    submodule, 't_value'):
                layer_idx_to_module[submodule.conv_idx] = submodule
        print(layer_idx_to_module)
        ######################################

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

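                # if_accum_grad is True on every iteration that is not a multiple of
                # cfg.grad_accum_iters; presumably train_one_step accumulates gradients
                # without applying the optimizer update while it is set.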
                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()

                ############    aofp
                #   1.  see if it is time to start on every layer
                #   2.  forward and accumulate
                #   3.  if a half on some layer is finished, do something
                #   ----    fetch its accumulated t vectors, analyze the first 'granu' elements
                #   ----    if good enough, set the base mask, reset the search space
                #   ----    elif granu == 1, do nothing
                #   ----    else, granu /= 2, reset the search space
                for layer_idx, start_iter in layer_to_start_iter.items():
                    if start_iter == iteration:
                        layer_idx_to_module[layer_idx].start_aofp(iteration)
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    if_accum_grad,
                    gradient_mask_tensor=gradient_mask_tensor,
                    lasso_keyword_to_strength=lasso_keyword_to_strength)
                for layer_idx, aofp_layer in layer_idx_to_module.items():
                    #   accumulate
                    if layer_idx not in succ_strategy:
                        continue
                    follow_layer_idx = succ_strategy[layer_idx]
                    if follow_layer_idx not in layer_idx_to_module:
                        continue
                    t_value = layer_idx_to_module[follow_layer_idx].t_value
                    aofp_layer.accumulate_t_value(t_value)
                    if aofp_layer.finished_a_half(iteration):
                        aofp_layer.halve_or_stop(iteration)
                ###################################

                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

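            # Recompute each pruned layer's surviving channel count from its base_mask,
            # compare the remaining FLOPs ratio with the target remain_flops_ratio, and
            # stop training once the ratio drops below the target.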
            cur_deps = np.array(cfg.deps)
            for submodule in model.modules():
                if hasattr(submodule, 'base_mask'):
                    cur_deps[submodule.conv_idx] = np.sum(
                        submodule.base_mask.cpu().numpy() == 1)
            origin_flops = flops_func(cfg.deps)
            cur_flops = flops_func(cur_deps)
            remain_ratio = cur_flops / origin_flops
            if local_rank == 0:
                print('##########################')
                print('origin deps ', cfg.deps)
                print('cur deps ', cur_deps)
                print('remain flops ratio = ', remain_ratio, 'the target is ',
                      remain_flops_ratio)
                print('##########################')
            if remain_ratio < remain_flops_ratio:
                break
            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))

        final_deps = aofp_prune(model,
                                origin_deps=cfg.deps,
                                succ_strategy=succ_strategy,
                                save_path=os.path.join(cfg.output_dir,
                                                       'finish_pruned.hdf5'))
        origin_flops = flops_func(cfg.deps)
        cur_flops = flops_func(final_deps)
        engine.log(
            '##################################################################'
        )
        engine.log(cfg.network_type)
        engine.log('origin width: {} , flops {} '.format(
            cfg.deps, origin_flops))
        engine.log('final width: {}, flops {} '.format(final_deps, cur_flops))
        engine.log('flops reduction: {}'.format(1 - cur_flops / origin_flops))
        return final_deps