Example #1
    def __init__(self, builder: ConvBuilder, in_planes, stage_deps, stride=1):
        super(ResNetBottleneckStage, self).__init__()
        print('building stage: in {}, deps {}'.format(in_planes, stage_deps))
        assert (len(stage_deps) - 1) % 3 == 0
        self.num_blocks = (len(stage_deps) - 1) // 3
        stage_out_channels = stage_deps[3]
        for i in range(2, self.num_blocks):
            assert stage_deps[3 * i] == stage_out_channels

        self.relu = builder.ReLU()

        self.projection = builder.Conv2dBN(in_channels=in_planes,
                                           out_channels=stage_deps[0],
                                           kernel_size=1,
                                           stride=stride)
        self.align_opr = builder.ResNetAlignOpr(channels=stage_deps[0])

        for i in range(self.num_blocks):
            in_c = in_planes if i == 0 else stage_out_channels
            block_stride = stride if i == 0 else 1
            self.__setattr__(
                'block{}'.format(i),
                BottleneckBranch(builder=builder,
                                 in_channels=in_c,
                                 deps=stage_deps[1 + i * 3:4 + i * 3],
                                 stride=block_stride))
Example #2
    def __init__(self, builder:ConvBuilder, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = builder.Conv2dBNReLU(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = builder.Conv2dBN(in_channels=planes, out_channels=self.expansion * planes, kernel_size=3, stride=1, padding=1)

        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = builder.Conv2dBN(in_channels=in_planes, out_channels=self.expansion * planes, kernel_size=1, stride=stride)
        else:
            self.shortcut = builder.ResIdentity(num_channels=in_planes)
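
Only the constructor is shown above. Below is a minimal sketch of the forward pass this block implies, assuming the usual residual pattern; the final ReLU and the use of torch.nn.functional are assumptions, not code from the source.

    def forward(self, x):  # sketch, assuming import torch.nn.functional as F
        out = self.conv1(x)           # 3x3 conv + BN + ReLU
        out = self.conv2(out)         # 3x3 conv + BN, no activation yet
        out = out + self.shortcut(x)  # 1x1 projection or identity
        return F.relu(out)            # activation after the residual addition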
Example #3
 def __init__(self, num_classes, builder: ConvBuilder, deps):
     super(VCNet, self).__init__()
     self.stem = _create_vgg_stem(builder=builder, deps=deps)
     self.flatten = builder.Flatten()
     self.linear1 = builder.IntermediateLinear(in_features=deps[12],
                                               out_features=512)
     self.relu = builder.ReLU()
     self.linear2 = builder.Linear(in_features=512,
                                   out_features=num_classes)
Example #4
    def __init__(self,
                 builder: ConvBuilder,
                 num_blocks,
                 num_classes=1000,
                 deps=None):
        super(SBottleneckResNet, self).__init__()
        # self.mean_tensor = torch.from_numpy(np.array([0.485, 0.456, 0.406])).reshape(1, 3, 1, 1).cuda().type(torch.cuda.FloatTensor)
        # self.std_tensor = torch.from_numpy(np.array([0.229, 0.224, 0.225])).reshape(1, 3, 1, 1).cuda().type(torch.cuda.FloatTensor)

        # self.mean_tensor = torch.from_numpy(np.array([0.406, 0.456, 0.485])).reshape(1, 3, 1, 1).cuda().type(
        #     torch.cuda.FloatTensor)
        # self.std_tensor = torch.from_numpy(np.array([0.225, 0.224, 0.229])).reshape(1, 3, 1, 1).cuda().type(
        #     torch.cuda.FloatTensor)

        # self.mean_tensor = torch.from_numpy(np.array([0.5, 0.5, 0.5])).reshape(1, 3, 1, 1).cuda().type(
        #     torch.cuda.FloatTensor)
        # self.std_tensor = torch.from_numpy(np.array([0.5, 0.5, 0.5])).reshape(1, 3, 1, 1).cuda().type(
        #     torch.cuda.FloatTensor)

        if deps is None:
            if num_blocks == [3, 4, 6, 3]:
                deps = RESNET50_ORIGIN_DEPS_FLATTENED
            elif num_blocks == [3, 4, 23, 3]:
                deps = resnet_bottleneck_origin_deps_flattened(101)
            else:
                raise ValueError('???')

        self.conv1 = builder.Conv2dBNReLU(3,
                                          deps[0],
                                          kernel_size=7,
                                          stride=2,
                                          padding=3)
        self.maxpool = builder.Maxpool2d(kernel_size=3, stride=2, padding=1)
        #   every stage has num_blocks * 3 + 1 layers
        nls = [n * 3 + 1 for n in num_blocks]  # num layers in each stage
        self.stage1 = ResNetBottleneckStage(builder=builder,
                                            in_planes=deps[0],
                                            stage_deps=deps[1:nls[0] + 1])
        self.stage2 = ResNetBottleneckStage(builder=builder,
                                            in_planes=deps[nls[0]],
                                            stage_deps=deps[nls[0] + 1:nls[0] +
                                                            1 + nls[1]],
                                            stride=2)
        self.stage3 = ResNetBottleneckStage(
            builder=builder,
            in_planes=deps[nls[0] + nls[1]],
            stage_deps=deps[nls[0] + nls[1] + 1:nls[0] + 1 + nls[1] + nls[2]],
            stride=2)
        self.stage4 = ResNetBottleneckStage(
            builder=builder,
            in_planes=deps[nls[0] + nls[1] + nls[2]],
            stage_deps=deps[nls[0] + nls[1] + nls[2] + 1:nls[0] + 1 + nls[1] +
                            nls[2] + nls[3]],
            stride=2)
        self.gap = builder.GAP(kernel_size=7)
        self.fc = builder.Linear(deps[-1], num_classes)
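
The slicing above is easier to follow with concrete numbers. For num_blocks = [3, 4, 6, 3] (ResNet-50) the index arithmetic works out as in the sketch below; the mapping is derived from the code above, not from any documentation.

num_blocks = [3, 4, 6, 3]              # ResNet-50
nls = [n * 3 + 1 for n in num_blocks]  # [10, 13, 19, 10]: 1 projection + 3 convs per block
# deps[0]      -> stem conv
# deps[1:11]   -> stage1 (in_planes = deps[0])
# deps[11:24]  -> stage2 (in_planes = deps[10], last conv of stage1)
# deps[24:43]  -> stage3 (in_planes = deps[23])
# deps[43:53]  -> stage4 (in_planes = deps[42])
# deps[-1]     -> in_features of the final Linear; len(deps) == 53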
Example #5
 def __init__(self, builder: ConvBuilder, in_channels, deps, stride=1):
     super(BottleneckBranch, self).__init__()
     assert len(deps) == 3
     self.conv1 = builder.Conv2dBNReLU(in_channels, deps[0], kernel_size=1)
     self.conv2 = builder.Conv2dBNReLU(deps[0],
                                       deps[1],
                                       kernel_size=3,
                                       stride=stride,
                                       padding=1)
     self.conv3 = builder.Conv2dBN(deps[1], deps[2], kernel_size=1)
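
The branch is purely sequential; the residual addition and activation are handled by the enclosing stage (see ResNetBottleneckStage above). A sketch of the implied forward, which is an assumption rather than code from the source:

    def forward(self, x):  # sketch
        out = self.conv1(x)      # 1x1 reduce (BN + ReLU)
        out = self.conv2(out)    # 3x3, possibly strided (BN + ReLU)
        return self.conv3(out)   # 1x1 expand (BN only)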
Example #6
 def __init__(self, builder: ConvBuilder, deps):
     super(LeNet5, self).__init__()
     self.bd = builder
     stem = builder.Sequential()
     stem.add_module(
         'conv1',
         builder.Conv2d(in_channels=1,
                        out_channels=LENET5_DEPS[0],
                        kernel_size=5,
                        bias=True))
     stem.add_module('maxpool1', builder.Maxpool2d(kernel_size=2))
     stem.add_module(
         'conv2',
         builder.Conv2d(in_channels=LENET5_DEPS[0],
                        out_channels=LENET5_DEPS[1],
                        kernel_size=5,
                        bias=True))
     stem.add_module('maxpool2', builder.Maxpool2d(kernel_size=2))
     self.stem = stem
     self.flatten = builder.Flatten()
     self.linear1 = builder.Linear(in_features=LENET5_DEPS[1] * 16,
                                   out_features=LENET5_DEPS[2])
     self.relu1 = builder.ReLU()
     self.linear2 = builder.Linear(in_features=LENET5_DEPS[2],
                                   out_features=10)
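
The in_features of linear1, LENET5_DEPS[1] * 16, follows from 28x28 MNIST inputs: two unpadded 5x5 convolutions and two 2x2 max-pools shrink the feature map to 4x4 = 16 positions per channel (28 -> 24 -> 12 -> 8 -> 4). A matching forward pass would plausibly look like the sketch below; it is an assumption, not part of the snippet.

    def forward(self, x):            # x: (N, 1, 28, 28)
        out = self.stem(x)           # (N, LENET5_DEPS[1], 4, 4)
        out = self.flatten(out)      # (N, LENET5_DEPS[1] * 16)
        out = self.relu1(self.linear1(out))
        return self.linear2(out)     # (N, 10) class logits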
Example #7
 def __init__(self, builder:ConvBuilder, num_classes):
     super(MobileV1CifarNet, self).__init__()
     self.conv1 = builder.Conv2dBNReLU(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
     blocks = []
     in_planes = cifar_cfg[0]
     for x in cifar_cfg:
         out_planes = x if isinstance(x, int) else x[0]
         stride = 1 if isinstance(x, int) else x[1]
         blocks.append(MobileV1Block(builder=builder, in_planes=in_planes, out_planes=out_planes, stride=stride))
         in_planes = out_planes
     self.stem = builder.Sequential(*blocks)
     self.gap = builder.GAP(kernel_size=8)
     self.linear = builder.Linear(cifar_cfg[-1], num_classes)
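
The loop reads each cifar_cfg entry either as a plain channel count (stride 1) or as a (channels, stride) tuple. cifar_cfg itself is defined elsewhere and is not shown; the configuration below is purely hypothetical and only illustrates the expected format.

# hypothetical config in the format the loop expects; not the real cifar_cfg
example_cfg = [16, 32, (64, 2), 64, (128, 2), 128]
for x in example_cfg:
    out_planes = x if isinstance(x, int) else x[0]
    stride = 1 if isinstance(x, int) else x[1]
    print(out_planes, stride)  # channel count and stride of each MobileV1Block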
Example #8
 def __init__(self, builder: ConvBuilder, in_planes, out_planes, stride=1):
     super(MobileV1Block, self).__init__()
     self.conv1 = builder.Conv2dBNReLU(in_channels=in_planes,
                                       out_channels=in_planes,
                                       kernel_size=3,
                                       stride=stride,
                                       padding=1,
                                       groups=in_planes)
     self.conv2 = builder.Conv2dBNReLU(in_channels=in_planes,
                                       out_channels=out_planes,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)
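
With groups=in_planes, conv1 is a 3x3 depthwise convolution and conv2 the 1x1 pointwise projection, i.e. a standard depthwise-separable unit. The implied forward is sequential; a sketch (assumed, not from the source):

    def forward(self, x):       # sketch
        out = self.conv1(x)     # 3x3 depthwise: one filter per input channel
        return self.conv2(out)  # 1x1 pointwise: mixes channels, changes width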
Example #9
    def __init__(self, block_counts, num_classes, builder:ConvBuilder, deps, use_dropout):
        super(WRNCifarNet, self).__init__()
        self.bd = builder
        converted_deps = wrn_convert_flattened_deps(deps)
        print('the converted deps is ', converted_deps)

        self.conv1 = builder.Conv2d(in_channels=3, out_channels=converted_deps[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.stage1 = self._build_wrn_stage(num_blocks=block_counts[0], stage_input_channels=converted_deps[0],
                                            stage_deps=converted_deps[1], downsample=False, use_dropout=use_dropout)
        self.stage2 = self._build_wrn_stage(num_blocks=block_counts[1], stage_input_channels=converted_deps[1][-1][1],
                                            stage_deps=converted_deps[2], downsample=True, use_dropout=use_dropout)
        self.stage3 = self._build_wrn_stage(num_blocks=block_counts[2], stage_input_channels=converted_deps[2][-1][1],
                                            stage_deps=converted_deps[3], downsample=True, use_dropout=use_dropout)
        self.last_bn = builder.BatchNorm2d(num_features=converted_deps[3][-1][1])
        self.linear = builder.Linear(in_features=converted_deps[3][-1][1], out_features=num_classes)
Example #10
 def __init__(self,
              conv_idx,
              builder: ConvBuilder,
              preced_layer_idx,
              in_features,
              out_features,
              bias=True):
     super(AOFPFCReluLayer, self).__init__()
     self.conv_idx = conv_idx
     self.base_path = builder.Linear(in_features=in_features,
                                     out_features=out_features,
                                     bias=bias)
     self.relu = builder.ReLU()
     self.register_buffer('t_value', torch.zeros(1))
     self.preced_layer_idx = preced_layer_idx
Example #11
    def __init__(self, conv_idx, builder:ConvBuilder, preced_layer_idx,
                 score_preced_layer, scored_by_follow_layer, iters_per_half, thresh,
                 in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros'):
        super(AOFPConvBNLayer, self).__init__()
        self.conv_idx = conv_idx
        self.base_path = builder.OriginConv2dBN(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                                stride=stride, padding=padding, dilation=dilation, groups=groups,
                                                padding_mode=padding_mode)

        if score_preced_layer:
            self.register_buffer('t_value', torch.zeros(1))
        if scored_by_follow_layer:
            self.register_buffer('half_start_iter', torch.zeros(1) + 9999999)
            self.register_buffer('base_mask', torch.ones(out_channels))
            self.register_buffer('score_mask', torch.ones(out_channels))
            self.register_buffer('search_space', torch.ones(out_channels))  # 1 indicates "in the search space"
            # self.reset_search_space()
            self.register_buffer('accumulated_t', torch.zeros(out_channels))
            self.register_buffer('accumulated_cnt', torch.zeros(out_channels))

        self.post = nn.Identity()
        self.preced_layer_idx = preced_layer_idx
        self.scored_by_follow_layer = scored_by_follow_layer
        self.score_preced_layer = score_preced_layer
        self.num_filters = out_channels
        self.iters_per_half = iters_per_half
        self.thresh = thresh
        self.aofp_started = False
Example #12
    def __init__(self, builder:ConvBuilder, num_classes, deps=None):
        super(MobileV1ImagenetNet, self).__init__()
        if deps is None:
            deps = MI1_ORIGIN_DEPS
        assert len(deps) == 27
        self.conv1 = builder.Conv2dBNReLU(in_channels=3, out_channels=deps[0], kernel_size=3, stride=2, padding=1)
        blocks = []
        for block_idx in range(13):
            depthwise_channels = int(deps[block_idx * 2 + 1])
            pointwise_channels = int(deps[block_idx * 2 + 2])
            stride = 2 if block_idx in [1, 3, 5, 11] else 1
            blocks.append(MobileV1Block(builder=builder, in_planes=depthwise_channels, out_planes=pointwise_channels, stride=stride))

        self.stem = builder.Sequential(*blocks)
        self.gap = builder.GAP(kernel_size=7)
        self.linear = builder.Linear(imagenet_cfg[-1], num_classes)
Example #13
 def __init__(self, builder:ConvBuilder, deps):
     super(LeNet5BN, self).__init__()
     self.bd = builder
     stem = builder.Sequential()
     stem.add_module('conv1', builder.Conv2dBNReLU(in_channels=1, out_channels=deps[0], kernel_size=5))
     stem.add_module('maxpool1', builder.Maxpool2d(kernel_size=2))
     stem.add_module('conv2', builder.Conv2dBNReLU(in_channels=deps[0], out_channels=deps[1], kernel_size=5))
     stem.add_module('maxpool2', builder.Maxpool2d(kernel_size=2))
     self.stem = stem
     self.flatten = builder.Flatten()
     self.linear1 = builder.IntermediateLinear(in_features=deps[1] * 16, out_features=500)
     self.relu1 = builder.ReLU()
     self.linear2 = builder.Linear(in_features=500, out_features=10)
Example #14
def ding_test(cfg:BaseConfigByEpoch, net=None, val_dataloader=None, show_variables=False, convbuilder=None,
               init_hdf5=None, extra_msg=None, weights_dict=None):

    with Engine(local_rank=0, for_val_only=True) as engine:

        engine.setup_log(
            name='test', log_dir='./', file_name=DETAIL_LOG_FILE)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder).cuda()
        else:
            model = net.cuda()

        if val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name, cfg.dataset_subset,
                                      global_batch_size=cfg.global_batch_size, distributed=False)
        else:
            val_data = val_dataloader
        num_examples = num_val_examples(cfg.dataset_name)
        assert num_examples % cfg.global_batch_size == 0
        val_iters = num_examples // cfg.global_batch_size
        print('batchsize={}, {} iters'.format(cfg.global_batch_size, val_iters))

        criterion = get_criterion(cfg).cuda()

        engine.register_state(
            scheduler=None, model=model, optimizer=None)

        if show_variables:
            engine.show_variables()

        assert not engine.distributed

        if weights_dict is not None:
            engine.load_from_weights_dict(weights_dict)
        else:
            if cfg.init_weights:
                engine.load_checkpoint(cfg.init_weights)
            if init_hdf5:
                engine.load_hdf5(init_hdf5)

        # engine.save_by_order('smi2_by_order.hdf5')
        # engine.load_by_order('smi2_by_order.hdf5')
        # engine.save_hdf5('model_files/stami2_lrs4Z.hdf5')

        model.eval()
        eval_dict, total_net_time = run_eval(val_data, val_iters, model, criterion, 'TEST', dataset_name=cfg.dataset_name)
        val_top1_value = eval_dict['top1'].item()
        val_top5_value = eval_dict['top5'].item()
        val_loss_value = eval_dict['loss'].item()

        msg = '{},{},{},top1={:.5f},top5={:.5f},loss={:.7f},total_net_time={}'.format(cfg.network_type, init_hdf5 or cfg.init_weights, cfg.dataset_subset,
                                                                    val_top1_value, val_top5_value, val_loss_value, total_net_time)
        if extra_msg is not None:
            msg += ', ' + extra_msg
        log_important(msg, OVERALL_LOG_FILE)
        return eval_dict
Example #15
    def __init__(self,  builder:ConvBuilder, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        # self.bn1 = nn.BatchNorm2d(planes)
        self.conv1 = builder.Conv2dBN(inplanes, planes, kernel_size=1)

        # self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
        #                        padding=1, bias=False)
        # self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = builder.Conv2dBN(planes, planes, kernel_size=3, stride=stride, padding=1)

        # self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        # self.bn3 = nn.BatchNorm2d(planes * 4)
        self.conv3 = builder.Conv2dBN(planes, planes * 4, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
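
The commented-out lines show the plain Conv2d/BatchNorm2d pairs that the builder's Conv2dBN calls replace. A sketch of the forward pass, following the standard torchvision-style bottleneck (assumed here, not shown in the snippet):

    def forward(self, x):  # sketch
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.conv1(x))    # 1x1 reduce
        out = self.relu(self.conv2(out))  # 3x3, possibly strided
        out = self.conv3(out)             # 1x1 expand, no activation yet
        return self.relu(out + residual)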
Example #16
    def __init__(self,
                 builder: ConvBuilder,
                 in_planes,
                 stage_deps,
                 stride=1,
                 is_first=False):
        super(ResNetBasicStage, self).__init__()
        print('building stage: in {}, deps {}'.format(in_planes, stage_deps))
        self.num_blocks = len(stage_deps) // 2

        stage_out_channels = stage_deps[0]
        for i in range(0, self.num_blocks):
            assert stage_deps[i * 2 + 2] == stage_out_channels

        if is_first:
            self.conv1 = builder.Conv2dBN(in_channels=in_planes,
                                          out_channels=stage_out_channels,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1)
            # self.projection = builder.ResIdentity(num_channels=stage_out_channels)
        else:
            self.projection = builder.Conv2dBN(in_channels=in_planes,
                                               out_channels=stage_out_channels,
                                               kernel_size=1,
                                               stride=stride)

        self.relu = builder.ReLU()
        self.align_opr = builder.ResNetAlignOpr(channels=stage_out_channels)

        for i in range(self.num_blocks):
            if i == 0 and is_first:
                in_c = stage_deps[0]
            elif i == 0:
                in_c = in_planes
            else:
                in_c = stage_out_channels
            block_stride = stride if i == 0 else 1
            self.__setattr__(
                'block{}'.format(i),
                BasicBranch(builder=builder,
                            in_channels=in_c,
                            deps=stage_deps[1 + i * 2:3 + i * 2],
                            stride=block_stride))
Example #17
    def __init__(self,
                 builder: ConvBuilder,
                 num_blocks,
                 num_classes=1000,
                 deps=None):
        super(SBottleneckResNet, self).__init__()

        if deps is None:
            if num_blocks == [3, 4, 6, 3]:
                deps = RESNET50_ORIGIN_DEPS_FLATTENED
            elif num_blocks == [3, 4, 23, 3]:
                deps = resnet_bottleneck_origin_deps_flattened(101)
            else:
                raise ValueError('???')

        self.conv1 = builder.Conv2dBNReLU(3,
                                          deps[0],
                                          kernel_size=7,
                                          stride=2,
                                          padding=3)
        self.maxpool = builder.Maxpool2d(kernel_size=3, stride=2, padding=1)
        #   every stage has num_blocks * 3 + 1 layers
        nls = [n * 3 + 1 for n in num_blocks]  # num layers in each stage
        self.stage1 = ResNetBottleneckStage(builder=builder,
                                            in_planes=deps[0],
                                            stage_deps=deps[1:nls[0] + 1])
        self.stage2 = ResNetBottleneckStage(builder=builder,
                                            in_planes=deps[nls[0]],
                                            stage_deps=deps[nls[0] + 1:nls[0] +
                                                            1 + nls[1]],
                                            stride=2)
        self.stage3 = ResNetBottleneckStage(
            builder=builder,
            in_planes=deps[nls[0] + nls[1]],
            stage_deps=deps[nls[0] + nls[1] + 1:nls[0] + 1 + nls[1] + nls[2]],
            stride=2)
        self.stage4 = ResNetBottleneckStage(
            builder=builder,
            in_planes=deps[nls[0] + nls[1] + nls[2]],
            stage_deps=deps[nls[0] + nls[1] + nls[2] + 1:nls[0] + 1 + nls[1] +
                            nls[2] + nls[3]],
            stride=2)
        self.gap = builder.GAP(kernel_size=7)
        self.fc = builder.Linear(deps[-1], num_classes)
Example #18
def ding_test(cfg:BaseConfigByEpoch, net=None, val_dataloader=None, show_variables=False, convbuilder=None,
               init_hdf5=None, ):

    with Engine() as engine:

        engine.setup_log(
            name='test', log_dir='./', file_name=DETAIL_LOG_FILE)

        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name, cfg.dataset_subset, batch_size=cfg.global_batch_size)
        val_iters = 50000 // cfg.global_batch_size if cfg.dataset_name == 'imagenet' else 10000 // cfg.global_batch_size

        print('NOTE: Data prepared')
        print('NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'.format(cfg.global_batch_size, torch.cuda.device_count(), torch.cuda.memory_allocated()))

        criterion = get_criterion(cfg).cuda()

        engine.register_state(
            scheduler=None, model=model, optimizer=None, cfg=cfg)

        if show_variables:
            engine.show_variables()

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[engine.world_rank],
                broadcast_buffers=False, )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights, just_weights=True)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)

        model.eval()
        eval_dict, _ = run_eval(val_dataloader, val_iters, model, criterion, 'TEST', dataset_name=cfg.dataset_name)
        val_top1_value = eval_dict['top1'].item()
        val_top5_value = eval_dict['top5'].item()
        val_loss_value = eval_dict['loss'].item()

        msg = '{},{},{},top1={:.5f},top5={:.5f},loss={:.7f}'.format(cfg.network_type, init_hdf5 or cfg.init_weights, cfg.dataset_subset,
                                                                    val_top1_value, val_top5_value, val_loss_value)
        log_important(msg, OVERALL_LOG_FILE)
Example #19
 def __init__(self, builder:ConvBuilder, block, num_blocks, num_classes=10):
     super(ResNet, self).__init__()
     self.bd = builder
     self.in_planes = 64
     self.conv1 = builder.Conv2dBNReLU(3, 64, kernel_size=7, stride=2, padding=3)
     self.stage1 = self._make_stage(block, 64, num_blocks[0], stride=1)
     self.stage2 = self._make_stage(block, 128, num_blocks[1], stride=2)
     self.stage3 = self._make_stage(block, 256, num_blocks[2], stride=2)
     self.stage4 = self._make_stage(block, 512, num_blocks[3], stride=2)
     self.linear = self.bd.Linear(512*block.expansion, num_classes)
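
_make_stage is referenced but not shown. Below is a minimal sketch of the conventional implementation, in which the first block carries the stride and the remaining blocks use stride 1; the exact signature in the source may differ.

    def _make_stage(self, block, planes, num_blocks, stride):  # sketch
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for s in strides:
            blocks.append(block(self.bd, self.in_planes, planes, stride=s))
            self.in_planes = planes * block.expansion
        return self.bd.Sequential(*blocks)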
Example #20
    def __init__(self,
                 builder: ConvBuilder,
                 block,
                 num_blocks,
                 num_classes=10,
                 width_multiplier=None):
        super(ResNet, self).__init__()

        print('width multiplier: ', width_multiplier)

        if width_multiplier is None:
            width_multiplier = 1
        else:
            width_multiplier = width_multiplier[0]

        self.bd = builder
        self.in_planes = int(64 * width_multiplier)
        self.conv1 = builder.Conv2dBNReLU(3,
                                          int(64 * width_multiplier),
                                          kernel_size=7,
                                          stride=2,
                                          padding=3)
        self.stage1 = self._make_stage(block,
                                       int(64 * width_multiplier),
                                       num_blocks[0],
                                       stride=1)
        self.stage2 = self._make_stage(block,
                                       int(128 * width_multiplier),
                                       num_blocks[1],
                                       stride=2)
        self.stage3 = self._make_stage(block,
                                       int(256 * width_multiplier),
                                       num_blocks[2],
                                       stride=2)
        self.stage4 = self._make_stage(block,
                                       int(512 * width_multiplier),
                                       num_blocks[3],
                                       stride=2)
        self.gap = builder.GAP(kernel_size=7)
        self.linear = self.bd.Linear(
            int(512 * block.expansion * width_multiplier), num_classes)
Example #21
    def __init__(self, input_channels, block_channels, stride,
                 projection_shortcut, use_dropout, builder: ConvBuilder):
        super(WRNCifarBlock, self).__init__()
        assert len(block_channels) == 2

        if projection_shortcut:
            self.proj = builder.BNReLUConv2d(in_channels=input_channels,
                                             out_channels=block_channels[1],
                                             kernel_size=1,
                                             stride=stride,
                                             padding=0)
        else:
            self.proj = builder.ResIdentity(num_channels=block_channels[1])

        self.conv1 = builder.BNReLUConv2d(in_channels=input_channels,
                                          out_channels=block_channels[0],
                                          kernel_size=3,
                                          stride=stride,
                                          padding=1)
        if use_dropout:
            self.dropout = builder.Dropout(keep_prob=0.7)
            print('use dropout for WRN')
        else:
            self.dropout = builder.Identity()
        self.conv2 = builder.BNReLUConv2d(in_channels=block_channels[0],
                                          out_channels=block_channels[1],
                                          kernel_size=3,
                                          stride=1,
                                          padding=1)
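
Because the convolutions are BNReLUConv2d units, this is a pre-activation wide residual block with dropout between the two 3x3 convolutions. A sketch of the implied forward (an assumption, not copied from the source):

    def forward(self, x):  # sketch
        out = self.conv1(x)        # BN + ReLU + 3x3 conv, possibly strided
        out = self.dropout(out)
        out = self.conv2(out)      # BN + ReLU + 3x3 conv
        return out + self.proj(x)  # 1x1 projection or identity shortcut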
Example #22
 def __init__(self, builder:ConvBuilder):
     super(LeNet300, self).__init__()
     self.flatten = builder.Flatten()
     self.linear1 = builder.Linear(in_features=28*28, out_features=300, bias=True)
     self.relu1 = builder.ReLU()
     self.linear2 = builder.Linear(in_features=300, out_features=100, bias=True)
     self.relu2 = builder.ReLU()
     self.linear3 = builder.Linear(in_features=100, out_features=10, bias=True)
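
LeNet-300-100 is a plain 784-300-100-10 MLP; a sketch of the matching forward pass (assumed, not shown in the snippet):

    def forward(self, x):                    # x: (N, 1, 28, 28) or (N, 784)
        out = self.flatten(x)                # (N, 784)
        out = self.relu1(self.linear1(out))  # 784 -> 300
        out = self.relu2(self.linear2(out))  # 300 -> 100
        return self.linear3(out)             # 100 -> 10 logits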
Example #23
def get_pose_net(cfg, is_train, joint_num):


    if cfg.acb:
        builder = ACNetBuilder(base_config=None, deploy=False, gamma_init=1)
    else:
        builder = ConvBuilder(base_config=None)

    backbone = ResNetBackbone(builder, cfg.resnet_type)
    head_net = HeadNet(joint_num)
    if is_train:
        backbone.init_weights()
        head_net.init_weights()

    model = ResPoseNet(backbone, head_net, joint_num)
    return model
Example #24
    def __init__(self, builder:ConvBuilder, resnet_type):

        resnet_spec = {
            # 18: (BasicBlock, [2, 2, 2, 2], [64, 64, 128, 256, 512], 'resnet18'),
            # 34: (BasicBlock, [3, 4, 6, 3], [64, 64, 128, 256, 512], 'resnet34'),
            50: (Bottleneck, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], 'resnet50'),
            101: (Bottleneck, [3, 4, 23, 3], [64, 256, 512, 1024, 2048], 'resnet101'),
            152: (Bottleneck, [3, 8, 36, 3], [64, 256, 512, 1024, 2048], 'resnet152')}
        block, layers, channels, name = resnet_spec[resnet_type]
        
        self.name = name
        self.inplanes = 64

        super(ResNetBackbone, self).__init__()
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        #                        bias=False)
        # self.bn1 = nn.BatchNorm2d(64)
        self.bd = builder
        self.conv1 = builder.Conv2dBN(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
Example #25
 def __init__(self, num_classes, builder: ConvBuilder, deps):
     super(VANet, self).__init__()
     sq = builder.Sequential()
     sq.add_module(
         'conv1',
         builder.Conv2dBNReLU(in_channels=3,
                              out_channels=deps[0],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv2',
         builder.Conv2dBNReLU(in_channels=deps[0],
                              out_channels=deps[1],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module('maxpool1', builder.Maxpool2d(kernel_size=2))
     sq.add_module(
         'conv3',
         builder.Conv2dBNReLU(in_channels=deps[1],
                              out_channels=deps[2],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv4',
         builder.Conv2dBNReLU(in_channels=deps[2],
                              out_channels=deps[3],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module('maxpool2', builder.Maxpool2d(kernel_size=2))
     sq.add_module(
         'conv5',
         builder.Conv2dBNReLU(in_channels=deps[3],
                              out_channels=deps[4],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv6',
         builder.Conv2dBNReLU(in_channels=deps[4],
                              out_channels=deps[5],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv7',
         builder.Conv2dBNReLU(in_channels=deps[5],
                              out_channels=deps[6],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module('maxpool3', builder.Maxpool2d(kernel_size=2))
     sq.add_module(
         'conv8',
         builder.Conv2dBNReLU(in_channels=deps[6],
                              out_channels=deps[7],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv9',
         builder.Conv2dBNReLU(in_channels=deps[7],
                              out_channels=deps[8],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv10',
         builder.Conv2dBNReLU(in_channels=deps[8],
                              out_channels=deps[9],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module('maxpool4', builder.Maxpool2d(kernel_size=2))
     sq.add_module(
         'conv11',
         builder.Conv2dBNReLU(in_channels=deps[9],
                              out_channels=deps[10],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv12',
         builder.Conv2dBNReLU(in_channels=deps[10],
                              out_channels=deps[11],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module(
         'conv13',
         builder.Conv2dBNReLU(in_channels=deps[11],
                              out_channels=deps[12],
                              kernel_size=3,
                              stride=1,
                              padding=1))
     sq.add_module('maxpool5', builder.Maxpool2d(kernel_size=2))
     self.stem = sq
     self.flatten = builder.Flatten()
     self.linear1 = builder.IntermediateLinear(in_features=deps[12],
                                               out_features=512)
     self.relu = builder.ReLU()
     self.linear2 = builder.Linear(in_features=512,
                                   out_features=num_classes)
Example #26
    log_dir = 'acnet_exps/{}_{}_train'.format(network_type, block_type)

    weight_decay_bias = weight_decay_strength
    config = get_baseconfig_by_epoch(network_type=network_type,
                                     dataset_name=get_dataset_name_by_model_name(network_type), dataset_subset='train',
                                     global_batch_size=batch_size, num_node=1,
                                     weight_decay=weight_decay_strength, optimizer_type='sgd', momentum=0.9,
                                     max_epochs=lrs.max_epochs, base_lr=lrs.base_lr, lr_epoch_boundaries=lrs.lr_epoch_boundaries, cosine_minimum=lrs.cosine_minimum,
                                     lr_decay_factor=lrs.lr_decay_factor,
                                     warmup_epochs=0, warmup_method='linear', warmup_factor=0,
                                     ckpt_iter_period=40000, tb_iter_period=100, output_dir=log_dir,
                                     tb_dir=log_dir, save_weights=None, val_epoch_period=5, linear_final_lr=lrs.linear_final_lr,
                                     weight_decay_bias=weight_decay_bias, deps=None)

    if block_type == 'acb':
        builder = ACNetBuilder(base_config=config, deploy=False, gamma_init=gamma_init)
    else:
        builder = ConvBuilder(base_config=config)

    target_weights = os.path.join(log_dir, 'finish.hdf5')
    if not os.path.exists(target_weights):
        train_main(local_rank=start_arg.local_rank, cfg=config, convbuilder=builder,
               show_variables=True, auto_continue=auto_continue)

    if block_type == 'acb' and start_arg.local_rank == 0:
        convert_acnet_weights(target_weights, target_weights.replace('.hdf5', '_deploy.hdf5'), eps=1e-5)
        deploy_builder = ACNetBuilder(base_config=config, deploy=True)
        general_test(network_type=network_type, weights=target_weights.replace('.hdf5', '_deploy.hdf5'),
                 builder=deploy_builder)
Example #27
def csgd_train_main(local_rank,
                    cfg: BaseConfigByEpoch,
                    target_deps,
                    succeeding_strategy,
                    pacesetter_dict,
                    centri_strength,
                    pruned_weights,
                    net=None,
                    train_dataloader=None,
                    val_dataloader=None,
                    show_variables=False,
                    convbuilder=None,
                    init_hdf5=None,
                    no_l2_keywords='depth',
                    use_nesterov=False,
                    load_weights_keyword=None,
                    keyword_to_lr_mult=None,
                    auto_continue=False,
                    save_hdf5_epochs=10000):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')

    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_data = create_dataset(cfg.dataset_name,
                                        cfg.dataset_subset,
                                        cfg.global_batch_size,
                                        distributed=engine.distributed)
        else:
            train_data = train_dataloader
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_data = create_dataset(cfg.dataset_name,
                                      'val',
                                      global_batch_size=100,
                                      distributed=False)
        elif cfg.val_epoch_period > 0:
            val_data = val_dataloader
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        if no_l2_keywords is None:
            no_l2_keywords = []
        if type(no_l2_keywords) is not list:
            no_l2_keywords = [no_l2_keywords]
        # For a target parameter, cancel its weight decay in optimizer, because the weight decay will be later encoded in the decay mat
        conv_idx = 0
        for k, v in model.named_parameters():
            if v.dim() != 4:
                continue
            print('prune {} from {} to {}'.format(conv_idx,
                                                  cfg.deps[conv_idx],
                                                  target_deps[conv_idx]))
            if target_deps[conv_idx] < cfg.deps[conv_idx]:
                no_l2_keywords.append(k.replace(KERNEL_KEYWORD, 'conv'))
                no_l2_keywords.append(k.replace(KERNEL_KEYWORD, 'bn'))
            conv_idx += 1
        print('no l2: ', no_l2_keywords)
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_hdf5(init_hdf5,
                             load_weights_keyword=load_weights_keyword)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        #   ===================================== prepare the clusters and matrices for C-SGD ==========
        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list(
        )

        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path,
                                            allow_pickle=True).item()
        else:
            if local_rank == 0:
                layer_idx_to_clusters = get_layer_idx_to_clusters(
                    kernel_namedvalue_list=kernel_namedvalue_list,
                    target_deps=target_deps,
                    pacesetter_dict=pacesetter_dict)
                if pacesetter_dict is not None:
                    for follower_idx, pacesetter_idx in pacesetter_dict.items(
                    ):
                        if pacesetter_idx in layer_idx_to_clusters:
                            layer_idx_to_clusters[
                                follower_idx] = layer_idx_to_clusters[
                                    pacesetter_idx]

                np.save(clusters_save_path, layer_idx_to_clusters)
            else:
                while not os.path.exists(clusters_save_path):
                    time.sleep(10)
                    print('sleep, waiting for process 0 to calculate clusters')
                layer_idx_to_clusters = np.load(clusters_save_path,
                                                allow_pickle=True).item()

        param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
            deps=cfg.deps,
            layer_idx_to_clusters=layer_idx_to_clusters,
            kernel_namedvalue_list=kernel_namedvalue_list)
        add_vecs_to_merge_mat_dicts(param_name_to_merge_matrix)
        param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
            deps=cfg.deps,
            layer_idx_to_clusters=layer_idx_to_clusters,
            kernel_namedvalue_list=kernel_namedvalue_list,
            weight_decay=cfg.weight_decay,
            weight_decay_bias=cfg.weight_decay_bias,
            centri_strength=centri_strength)
        print(param_name_to_decay_matrix.keys())
        print(param_name_to_merge_matrix.keys())

        conv_idx = 0
        param_to_clusters = {}
        for k, v in model.named_parameters():
            if v.dim() != 4:
                continue
            if conv_idx in layer_idx_to_clusters:
                for clsts in layer_idx_to_clusters[conv_idx]:
                    if len(clsts) > 1:
                        param_to_clusters[v] = layer_idx_to_clusters[conv_idx]
                        break
            conv_idx += 1
        #   ============================================================================================

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    param_name_to_merge_matrix=param_name_to_merge_matrix,
                    param_name_to_decay_matrix=param_name_to_decay_matrix)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(
                            tb_tags,
                        [acc.item(), acc5.item(),
                         loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)
                    deviation_sum = 0
                    for param, clusters in param_to_clusters.items():
                        pvalue = param.detach().cpu().numpy()
                        for cl in clusters:
                            if len(cl) == 1:
                                continue
                            selected = pvalue[cl, :, :, :]
                            mean_kernel = np.mean(selected,
                                                  axis=0,
                                                  keepdims=True)
                            diff = selected - mean_kernel
                            deviation_sum += np.sum(diff**2)
                    tb_writer.add_scalars('deviation_sum',
                                          {'Train': deviation_sum}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))

    if local_rank == 0:
        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
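
train_one_step is not shown here. What distinguishes C-SGD is how param_name_to_merge_matrix and param_name_to_decay_matrix are applied to the gradient before the optimizer step. The following is only a rough sketch of that idea, based on the Centripetal SGD formulation and the names used above; the real code's parameter matching and reshaping (e.g. for the vectors added by add_vecs_to_merge_mat_dicts) may differ.

# rough sketch of the C-SGD gradient transform assumed to live inside train_one_step
def csgd_gradient_step(model, optimizer, merge_mats, decay_mats):
    for name, param in model.named_parameters():
        if name in merge_mats:
            w = param.detach().reshape(param.shape[0], -1)  # one row per filter
            g = param.grad.reshape(param.shape[0], -1)
            merged = merge_mats[name].matmul(g)   # average gradients within each cluster
            decayed = decay_mats[name].matmul(w)  # weight decay + centripetal pull toward the cluster mean
            param.grad.copy_((merged + decayed).reshape(param.shape))
    optimizer.step()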
Example #28
def ding_train(cfg:BaseConfigByEpoch, net=None, train_dataloader=None, val_dataloader=None, show_variables=False, convbuilder=None, beginning_msg=None,
               init_hdf5=None, no_l2_keywords=None, gradient_mask=None, use_nesterov=False):

    # LOCAL_RANK = 0
    #
    # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # is_distributed = num_gpus > 1
    #
    # if is_distributed:
    #     torch.cuda.set_device(LOCAL_RANK)
    #     torch.distributed.init_process_group(
    #         backend="nccl", init_method="env://"
    #     )
    #     synchronize()
    #
    # torch.backends.cudnn.benchmark = True

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine() as engine:

        is_main_process = (engine.world_rank == 0) #TODO correct?

        logger = engine.setup_log(
            name='train', log_dir=cfg.output_dir, file_name='log.txt')

        # -- typical model components model, opt,  scheduler,  dataloder --#
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name, cfg.dataset_subset, cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name, 'val', batch_size=100)    #TODO 100?

        print('NOTE: Data prepared')
        print('NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'.format(cfg.global_batch_size, torch.cuda.device_count(), torch.cuda.memory_allocated()))

        # device = torch.device(cfg.device)
        # model.to(device)
        # model.cuda()

        if no_l2_keywords is None:
            no_l2_keywords = []
        optimizer = get_optimizer(cfg, model, no_l2_keywords=no_l2_keywords, use_nesterov=use_nesterov)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(
            scheduler=scheduler, model=model, optimizer=optimizer)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[engine.world_rank],
                broadcast_buffers=False, )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        # for k, v in model.named_parameters():
        #     if v.dim() in [2, 4]:
        #         torch.nn.init.xavier_normal_(v)
        #         print('init {} as xavier_normal'.format(k))
        #     if 'bias' in k and 'bn' not in k.lower():
        #         torch.nn.init.zeros_(v)
        #         print('init {} as zero'.format(k))

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights, is_restore=True)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)


        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        if beginning_msg:
            engine.log(beginning_msg)
        logger.info("\n\nStart training with pytorch version {}".format(torch.__version__))

        iteration = engine.state.iteration
        # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch

        engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        # summary(model=model, input_size=(224, 224) if cfg.dataset_name == 'imagenet' else (32, 32), batch_size=cfg.global_batch_size)

        recorded_train_time = 0
        recorded_train_examples = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            pbar = tqdm(range(iters_per_epoch))
            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)


            if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                model.eval()
                val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for validation on ImageNet and CIFAR
                eval_dict, _ = run_eval(val_dataloader, val_iters, model, criterion, discrip_str, dataset_name=cfg.dataset_name)
                val_top1_value = eval_dict['top1'].item()
                val_top5_value = eval_dict['top5'].item()
                val_loss_value = eval_dict['loss'].item()
                for tag, value in zip(tb_tags, [val_top1_value, val_top5_value, val_loss_value]):
                    tb_writer.add_scalars(tag, {'Val': value}, iteration)
                engine.log('validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'.format(epoch, val_top1_value, val_top5_value, val_loss_value))
                model.train()

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(model, data, label, optimizer, criterion, if_accum_grad, gradient_mask_tensor=gradient_mask_tensor)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                if iteration % cfg.tb_iter_period == 0 and is_main_process:
                    for tag, value in zip(tb_tags, [acc.item(), acc5.item(), loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)


                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and is_main_process):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                iteration += 1
                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            if iteration >= max_iters:
                break
        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters), int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
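The loop above hands the forward/backward pass to train_one_step and uses if_accum_grad to step the optimizer only once every cfg.grad_accum_iters iterations (gradient accumulation), optionally applying per-parameter gradient masks. The repository's train_one_step is not part of this listing; the following is only a minimal sketch of what such a step could look like, with the top-1/top-5 accuracy computed inline and the mask semantics assumed rather than taken from the source:

import torch

def train_one_step_sketch(model, data, label, optimizer, criterion, if_accum_grad,
                          gradient_mask_tensor=None):
    # forward + backward; gradients accumulate in .grad until we decide to step
    pred = model(data)
    loss = criterion(pred, label)
    loss.backward()

    # optionally mask gradients (e.g. to keep pruned weights frozen) -- assumed semantics
    if gradient_mask_tensor is not None:
        for name, p in model.named_parameters():
            if name in gradient_mask_tensor and p.grad is not None:
                p.grad.mul_(gradient_mask_tensor[name])

    # update the weights only on the boundary iteration of the accumulation window
    if not if_accum_grad:
        optimizer.step()
        optimizer.zero_grad()

    # top-1 / top-5 accuracy, computed inline so the sketch is self-contained
    with torch.no_grad():
        _, top5_pred = pred.topk(5, dim=1)
        correct = top5_pred.eq(label.view(-1, 1))
        acc = correct[:, :1].float().sum() * (100.0 / label.size(0))
        acc5 = correct.float().sum() * (100.0 / label.size(0))
    return acc, acc5, loss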
Example No. 29
0
def csgd_train_and_prune(cfg: BaseConfigByEpoch,
                         target_deps,
                         centri_strength,
                         pacesetter_dict,
                         succeeding_strategy,
                         pruned_weights,
                         net=None,
                         train_dataloader=None,
                         val_dataloader=None,
                         show_variables=False,
                         convbuilder=None,
                         beginning_msg=None,
                         init_hdf5=None,
                         no_l2_keywords=None,
                         use_nesterov=False,
                         tensorflow_style_init=False):

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    clusters_save_path = os.path.join(cfg.output_dir, 'clusters.npy')

    with Engine() as engine:

        is_main_process = (engine.world_rank == 0)  #TODO correct?

        logger = engine.setup_log(name='train',
                                  log_dir=cfg.output_dir,
                                  file_name='log.txt')

        # -- typical model components: model, optimizer, scheduler, dataloader -- #
        if net is None:
            net = get_model_fn(cfg.dataset_name, cfg.network_type)

        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)

        model = net(cfg, convbuilder).cuda()

        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name,
                                              cfg.dataset_subset,
                                              cfg.global_batch_size)
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                                            'val',
                                            batch_size=100)  #TODO 100?

        print('NOTE: Data prepared')
        print(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))

        optimizer = get_optimizer(cfg, model, use_nesterov=use_nesterov)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()

        # model, optimizer = amp.initialize(model, optimizer, opt_level="O0")

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer,
                              cfg=cfg)

        if engine.distributed:
            print('Distributed training, engine.world_rank={}'.format(
                engine.world_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[engine.world_rank],
                broadcast_buffers=False,
            )
            # model = DistributedDataParallel(model, delay_allreduce=True)
        elif torch.cuda.device_count() > 1:
            print('Single machine multiple GPU training')
            model = torch.nn.parallel.DataParallel(model)

        if tensorflow_style_init:
            for k, v in model.named_parameters():
                if v.dim() in [2, 4]:
                    torch.nn.init.xavier_uniform_(v)
                    print('init {} as xavier_uniform'.format(k))
                if 'bias' in k and 'bn' not in k.lower():
                    torch.nn.init.zeros_(v)
                    print('init {} as zero'.format(k))

        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)

        if init_hdf5:
            engine.load_hdf5(init_hdf5)

        kernel_namedvalue_list = engine.get_all_conv_kernel_namedvalue_as_list()

        if os.path.exists(clusters_save_path):
            layer_idx_to_clusters = np.load(clusters_save_path, allow_pickle=True).item()
        else:
            layer_idx_to_clusters = get_layer_idx_to_clusters(
                kernel_namedvalue_list=kernel_namedvalue_list,
                target_deps=target_deps,
                pacesetter_dict=pacesetter_dict)
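            # layers tied to a "pacesetter" (e.g. convs whose outputs are added in a residual
            # connection) copy its clustering so pruning keeps their channel counts aligned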
            if pacesetter_dict is not None:
                for follower_idx, pacesetter_idx in pacesetter_dict.items():
                    if pacesetter_idx in layer_idx_to_clusters:
                        layer_idx_to_clusters[
                            follower_idx] = layer_idx_to_clusters[
                                pacesetter_idx]

            np.save(clusters_save_path, layer_idx_to_clusters)

        csgd_save_file = os.path.join(cfg.output_dir, 'finish.hdf5')
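        # training results are cached in finish.hdf5: if it already exists, load it and skip
        # straight to pruning; otherwise train with the merge/decay matrices built below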

        if os.path.exists(csgd_save_file):
            engine.load_hdf5(csgd_save_file)
        else:
            param_name_to_merge_matrix = generate_merge_matrix_for_kernel(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list)
            param_name_to_decay_matrix = generate_decay_matrix_for_kernel_and_vecs(
                deps=cfg.deps,
                layer_idx_to_clusters=layer_idx_to_clusters,
                kernel_namedvalue_list=kernel_namedvalue_list,
                weight_decay=cfg.weight_decay,
                centri_strength=centri_strength)
            # if pacesetter_dict is not None:
            #     for follower_idx, pacesetter_idx in pacesetter_dict.items():
            #         follower_kernel_name = kernel_namedvalue_list[follower_idx].name
            #         pacesetter_kernel_name = kernel_namedvalue_list[pacesetter_idx].name
            #         if pacesetter_kernel_name in param_name_to_merge_matrix:
            #             param_name_to_merge_matrix[follower_kernel_name] = param_name_to_merge_matrix[
            #                 pacesetter_kernel_name]
            #             param_name_to_decay_matrix[follower_kernel_name] = param_name_to_decay_matrix[
            #                 pacesetter_kernel_name]

            add_vecs_to_mat_dicts(param_name_to_merge_matrix)

            if show_variables:
                engine.show_variables()

            if beginning_msg:
                engine.log(beginning_msg)

            logger.info("\n\nStart training with pytorch version {}".format(
                torch.__version__))

            iteration = engine.state.iteration
            # done_epochs = iteration // num_train_examples_per_epoch(cfg.dataset_name)
            iters_per_epoch = num_iters_per_epoch(cfg)
            max_iters = iters_per_epoch * cfg.max_epochs
            tb_writer = SummaryWriter(cfg.tb_dir)
            tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

            model.train()

            done_epochs = iteration // iters_per_epoch

            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

            recorded_train_time = 0
            recorded_train_examples = 0

            for epoch in range(done_epochs, cfg.max_epochs):

                pbar = tqdm(range(iters_per_epoch))
                top1 = AvgMeter()
                top5 = AvgMeter()
                losses = AvgMeter()
                discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
                pbar.set_description('Train' + discrip_str)

                if cfg.val_epoch_period > 0 and epoch % cfg.val_epoch_period == 0:
                    model.eval()
                    val_iters = 500 if cfg.dataset_name == 'imagenet' else 100  # use batch_size=100 for validation on ImageNet and CIFAR
                    eval_dict, _ = run_eval(val_dataloader,
                                            val_iters,
                                            model,
                                            criterion,
                                            discrip_str,
                                            dataset_name=cfg.dataset_name)
                    val_top1_value = eval_dict['top1'].item()
                    val_top5_value = eval_dict['top5'].item()
                    val_loss_value = eval_dict['loss'].item()
                    for tag, value in zip(tb_tags,
                                          [val_top1_value, val_top5_value, val_loss_value]):
                        tb_writer.add_scalars(tag, {'Val': value}, iteration)
                    engine.log(
                        'validate at epoch {}, top1={:.5f}, top5={:.5f}, loss={:.6f}'
                        .format(epoch, val_top1_value, val_top5_value,
                                val_loss_value))
                    model.train()

                for _ in pbar:

                    start_time = time.time()
                    data, label = load_cuda_data(train_dataloader,
                                                 cfg.dataset_name)
                    data_time = time.time() - start_time

                    train_net_time_start = time.time()
                    acc, acc5, loss = train_one_step(
                        model,
                        data,
                        label,
                        optimizer,
                        criterion,
                        param_name_to_merge_matrix=param_name_to_merge_matrix,
                        param_name_to_decay_matrix=param_name_to_decay_matrix)
                    train_net_time_end = time.time()

                    if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                        recorded_train_examples += cfg.global_batch_size
                        recorded_train_time += train_net_time_end - train_net_time_start

                    scheduler.step()

                    if iteration % cfg.tb_iter_period == 0 and is_main_process:
                        for tag, value in zip(tb_tags,
                                              [acc.item(), acc5.item(), loss.item()]):
                            tb_writer.add_scalars(tag, {'Train': value}, iteration)

                    top1.update(acc.item())
                    top5.update(acc5.item())
                    losses.update(loss.item())

                    pbar_dic = OrderedDict()
                    pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                    pbar_dic['cur_iter'] = iteration
                    pbar_dic['lr'] = scheduler.get_lr()[0]
                    pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                    pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                    pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                    pbar.set_postfix(pbar_dic)

                    if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                        engine.update_iteration(iteration)
                        if (not engine.distributed) or (engine.distributed
                                                        and is_main_process):
                            engine.save_and_link_checkpoint(cfg.output_dir)

                    iteration += 1
                    if iteration >= max_iters:
                        break

                #   do something after an epoch?
                if iteration >= max_iters:
                    break
            #   do something after the training
            if recorded_train_time > 0:
                exp_per_sec = recorded_train_examples / recorded_train_time
            else:
                exp_per_sec = 0
            engine.log(
                'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
                .format(int(TRAIN_SPEED_START * max_iters),
                        int(TRAIN_SPEED_END * max_iters),
                        cfg.global_batch_size, recorded_train_examples,
                        recorded_train_time, exp_per_sec))
            if cfg.save_weights:
                engine.save_checkpoint(cfg.save_weights)
                print('NOTE: training finished, saved to {}'.format(
                    cfg.save_weights))
            engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))

        csgd_prune_and_save(engine=engine,
                            layer_idx_to_clusters=layer_idx_to_clusters,
                            save_file=pruned_weights,
                            succeeding_strategy=succeeding_strategy,
                            new_deps=target_deps)
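csgd_train_and_prune differs from the plain trainer above in that train_one_step receives param_name_to_merge_matrix and param_name_to_decay_matrix: Centripetal SGD averages the gradients of filters that belong to the same cluster (the merge matrix) and adds a decay term that pulls those filters toward their cluster mean (the decay matrix), so they converge to identical values and csgd_prune_and_save can later remove the redundant ones. Below is a hedged sketch of that per-kernel update, reshaping the 4-D kernel to (filters, rest); the repository's actual implementation may differ in details:

import torch

def csgd_kernel_update(kernel, merge_matrix, decay_matrix, lr):
    # kernel: (out_channels, in_channels, kh, kw); both matrices are (out_channels, out_channels)
    g = kernel.grad.reshape(kernel.size(0), -1)      # per-filter gradients
    w = kernel.detach().reshape(kernel.size(0), -1)  # per-filter weights
    # merged gradient: filters in the same cluster share the averaged gradient;
    # decay term: ordinary weight decay plus the centripetal pull toward the cluster mean
    update = merge_matrix.matmul(g) + decay_matrix.matmul(w)
    kernel.data.add_(update.reshape(kernel.shape), alpha=-lr)

In such a scheme this update would replace the plain SGD step for every conv kernel listed in the two dictionaries, while all other parameters keep the ordinary optimizer update.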
Example No. 30
0
def train_main(local_rank,
               cfg: BaseConfigByEpoch,
               net=None,
               train_dataloader=None,
               val_dataloader=None,
               show_variables=False,
               convbuilder=None,
               init_hdf5=None,
               no_l2_keywords='depth',
               gradient_mask=None,
               use_nesterov=False,
               tensorflow_style_init=False,
               load_weights_keyword=None,
               keyword_to_lr_mult=None,
               auto_continue=False,
               lasso_keyword_to_strength=None,
               save_hdf5_epochs=10000):

    if no_l2_keywords is None:
        no_l2_keywords = []
    if not isinstance(no_l2_keywords, list):
        no_l2_keywords = [no_l2_keywords]

    ensure_dir(cfg.output_dir)
    ensure_dir(cfg.tb_dir)
    with Engine(local_rank=local_rank) as engine:
        engine.setup_log(name='train',
                         log_dir=cfg.output_dir,
                         file_name='log.txt')

        # ----------------------------- build model ------------------------------
        if convbuilder is None:
            convbuilder = ConvBuilder(base_config=cfg)
        if net is None:
            net_fn = get_model_fn(cfg.dataset_name, cfg.network_type)
            model = net_fn(cfg, convbuilder)
        else:
            model = net
        model = model.cuda()
        # ----------------------------- model done ------------------------------

        # ---------------------------- prepare data -------------------------
        if train_dataloader is None:
            train_dataloader = create_dataset(cfg.dataset_name,
                                              cfg.dataset_subset,
                                              cfg.global_batch_size,
                                              distributed=engine.distributed)
        train_data = train_dataloader
        if cfg.val_epoch_period > 0 and val_dataloader is None:
            val_dataloader = create_dataset(cfg.dataset_name,
                                            'val',
                                            global_batch_size=100,
                                            distributed=False)
        val_data = val_dataloader
        engine.echo('NOTE: Data prepared')
        engine.echo(
            'NOTE: We have global_batch_size={} on {} GPUs, the allocated GPU memory is {}'
            .format(cfg.global_batch_size, torch.cuda.device_count(),
                    torch.cuda.memory_allocated()))
        # ----------------------------- data done --------------------------------

        # ------------------------ prepare optimizer, scheduler, criterion -------
        optimizer = get_optimizer(engine,
                                  cfg,
                                  model,
                                  no_l2_keywords=no_l2_keywords,
                                  use_nesterov=use_nesterov,
                                  keyword_to_lr_mult=keyword_to_lr_mult)
        scheduler = get_lr_scheduler(cfg, optimizer)
        criterion = get_criterion(cfg).cuda()
        # --------------------------------- done -------------------------------

        engine.register_state(scheduler=scheduler,
                              model=model,
                              optimizer=optimizer)

        if engine.distributed:
            torch.cuda.set_device(local_rank)
            engine.echo('Distributed training, device {}'.format(local_rank))
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                broadcast_buffers=False,
            )
        else:
            assert torch.cuda.device_count() == 1
            engine.echo('Single GPU training')

        if tensorflow_style_init:
            init_as_tensorflow(model)
        if cfg.init_weights:
            engine.load_checkpoint(cfg.init_weights)
        if init_hdf5:
            engine.load_hdf5(init_hdf5,
                             load_weights_keyword=load_weights_keyword)
        if auto_continue:
            assert cfg.init_weights is None
            engine.load_checkpoint(get_last_checkpoint(cfg.output_dir))
        if show_variables:
            engine.show_variables()

        # ------------ do training ---------------------------- #
        engine.log("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        iteration = engine.state.iteration
        iters_per_epoch = num_iters_per_epoch(cfg)
        max_iters = iters_per_epoch * cfg.max_epochs
        tb_writer = SummaryWriter(cfg.tb_dir)
        tb_tags = ['Top1-Acc', 'Top5-Acc', 'Loss']

        model.train()

        done_epochs = iteration // iters_per_epoch
        last_epoch_done_iters = iteration % iters_per_epoch

        if done_epochs == 0 and last_epoch_done_iters == 0:
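            # dump the initial weights only when starting from scratch, not when resuming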
            engine.save_hdf5(os.path.join(cfg.output_dir, 'init.hdf5'))

        recorded_train_time = 0
        recorded_train_examples = 0

        collected_train_loss_sum = 0
        collected_train_loss_count = 0

        if gradient_mask is not None:
            gradient_mask_tensor = {}
            for name, value in gradient_mask.items():
                gradient_mask_tensor[name] = torch.Tensor(value).cuda()
        else:
            gradient_mask_tensor = None

        for epoch in range(done_epochs, cfg.max_epochs):

            if engine.distributed and hasattr(train_data, 'train_sampler'):
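                # tell the DistributedSampler the epoch number so each epoch is shuffled differently across workers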
                train_data.train_sampler.set_epoch(epoch)

            if epoch == done_epochs:
                pbar = tqdm(range(iters_per_epoch - last_epoch_done_iters))
            else:
                pbar = tqdm(range(iters_per_epoch))

            if epoch == 0 and local_rank == 0 and cfg.val_epoch_period > 0:
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str='Init',
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            top1 = AvgMeter()
            top5 = AvgMeter()
            losses = AvgMeter()
            discrip_str = 'Epoch-{}/{}'.format(epoch, cfg.max_epochs)
            pbar.set_description('Train' + discrip_str)

            for _ in pbar:

                start_time = time.time()
                data, label = load_cuda_data(train_data,
                                             dataset_name=cfg.dataset_name)

                # load_cuda_data(train_dataloader, cfg.dataset_name)
                data_time = time.time() - start_time

                if_accum_grad = ((iteration % cfg.grad_accum_iters) != 0)

                train_net_time_start = time.time()
                acc, acc5, loss = train_one_step(
                    model,
                    data,
                    label,
                    optimizer,
                    criterion,
                    if_accum_grad,
                    gradient_mask_tensor=gradient_mask_tensor,
                    lasso_keyword_to_strength=lasso_keyword_to_strength)
                train_net_time_end = time.time()

                if iteration > TRAIN_SPEED_START * max_iters and iteration < TRAIN_SPEED_END * max_iters:
                    recorded_train_examples += cfg.global_batch_size
                    recorded_train_time += train_net_time_end - train_net_time_start

                scheduler.step()

                for module in model.modules():
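                    # notify any submodule that keeps its own iteration counter (only if it defines set_cur_iter)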
                    if hasattr(module, 'set_cur_iter'):
                        module.set_cur_iter(iteration)

                if iteration % cfg.tb_iter_period == 0 and engine.world_rank == 0:
                    for tag, value in zip(tb_tags,
                                          [acc.item(), acc5.item(), loss.item()]):
                        tb_writer.add_scalars(tag, {'Train': value}, iteration)

                top1.update(acc.item())
                top5.update(acc5.item())
                losses.update(loss.item())

                if epoch >= cfg.max_epochs - COLLECT_TRAIN_LOSS_EPOCHS:
                    collected_train_loss_sum += loss.item()
                    collected_train_loss_count += 1

                pbar_dic = OrderedDict()
                pbar_dic['data-time'] = '{:.2f}'.format(data_time)
                pbar_dic['cur_iter'] = iteration
                pbar_dic['lr'] = scheduler.get_lr()[0]
                pbar_dic['top1'] = '{:.5f}'.format(top1.mean)
                pbar_dic['top5'] = '{:.5f}'.format(top5.mean)
                pbar_dic['loss'] = '{:.5f}'.format(losses.mean)
                pbar.set_postfix(pbar_dic)

                iteration += 1

                if iteration >= max_iters or iteration % cfg.ckpt_iter_period == 0:
                    engine.update_iteration(iteration)
                    if (not engine.distributed) or (engine.distributed and
                                                    engine.world_rank == 0):
                        engine.save_and_link_checkpoint(cfg.output_dir)

                if iteration >= max_iters:
                    break

            #   do something after an epoch?
            engine.update_iteration(iteration)
            engine.save_latest_ckpt(cfg.output_dir)

            if (epoch + 1) % save_hdf5_epochs == 0:
                engine.save_hdf5(
                    os.path.join(cfg.output_dir,
                                 'epoch-{}.hdf5'.format(epoch)))

            if local_rank == 0 and \
                    cfg.val_epoch_period > 0 and (epoch >= cfg.max_epochs - 10 or epoch % cfg.val_epoch_period == 0):
                val_during_train(epoch=epoch,
                                 iteration=iteration,
                                 tb_tags=tb_tags,
                                 engine=engine,
                                 model=model,
                                 val_data=val_data,
                                 criterion=criterion,
                                 descrip_str=discrip_str,
                                 dataset_name=cfg.dataset_name,
                                 test_batch_size=TEST_BATCH_SIZE,
                                 tb_writer=tb_writer)

            if iteration >= max_iters:
                break

        #   do something after the training
        if recorded_train_time > 0:
            exp_per_sec = recorded_train_examples / recorded_train_time
        else:
            exp_per_sec = 0
        engine.log(
            'TRAIN speed: from {} to {} iterations, batch_size={}, examples={}, total_net_time={:.4f}, examples/sec={}'
            .format(int(TRAIN_SPEED_START * max_iters),
                    int(TRAIN_SPEED_END * max_iters), cfg.global_batch_size,
                    recorded_train_examples, recorded_train_time, exp_per_sec))
        if cfg.save_weights:
            engine.save_checkpoint(cfg.save_weights)
            print('NOTE: training finished, saved to {}'.format(
                cfg.save_weights))
        engine.save_hdf5(os.path.join(cfg.output_dir, 'finish.hdf5'))
        if collected_train_loss_count > 0:
            engine.log(
                'TRAIN LOSS collected over last {} epochs: {:.6f}'.format(
                    COLLECT_TRAIN_LOSS_EPOCHS,
                    collected_train_loss_sum / collected_train_loss_count))
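Because train_main takes local_rank as its first positional argument, it matches the calling convention of torch.multiprocessing.spawn, which passes the process index first. The launcher below is only a hypothetical sketch; how this repository actually constructs its BaseConfigByEpoch and sets up the distributed environment variables is not shown in this listing:

import torch
import torch.multiprocessing as mp

def launch(cfg):
    # one worker process per visible GPU; mp.spawn calls train_main(rank, cfg) for rank 0..n-1
    world_size = torch.cuda.device_count()
    if world_size > 1:
        mp.spawn(train_main, args=(cfg,), nprocs=world_size)
    else:
        train_main(local_rank=0, cfg=cfg)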