Example #1
    def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1, residual_block=None, dropout=None, mixup=False, BN_momentum=0.1, conv_options=None, fixup=False):
        downsample = None
        out_planes = planes * expansion
        if stride != 1 or self.inplanes != out_planes:
            if fixup:
                downsample = conv1x1(self.inplanes, out_planes, stride, conv_options=conv_options)
            else:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, out_planes, stride, conv_options=conv_options),
                    nn.BatchNorm2d(out_planes, momentum=BN_momentum),
                )
        if residual_block is not None:
            residual_block = residual_block(out_planes)

        layers = []
        layers.append(block(self.inplanes, planes, stride, expansion=expansion,
                            downsample=downsample, groups=groups, residual_block=residual_block, dropout=dropout, BN_momentum=BN_momentum, conv_options=conv_options))
        self.inplanes = planes * expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups,
                                residual_block=residual_block, dropout=dropout, BN_momentum=BN_momentum, conv_options=conv_options))
        if mixup:
            layers.append(MixUp())
        return nn.Sequential(*layers)
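
Example #1 relies on a `conv1x1` helper that is not shown on this page. Below is a minimal sketch of what such a helper usually looks like; the `conv_options` pass-through is an assumption inferred from the call sites above, not the project's actual definition:

import torch.nn as nn

def conv1x1(in_planes, out_planes, stride=1, conv_options=None):
    """1x1 convolution without bias; extra options are forwarded to nn.Conv2d."""
    conv_options = conv_options or {}
    return nn.Conv2d(in_planes, out_planes, kernel_size=1,
                     stride=stride, bias=False, **conv_options)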
Example #2
def create_mixup_or_none(alpha, num_classes, comm):
    from utils.mixup import MixUp
    # Create different random generators over workers.
    rng = np.random.RandomState(726 + comm.rank)
    if alpha > 0:
        return MixUp(alpha, num_classes, rng)
    return None
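
Example #2 only constructs the mixup helper with an `alpha` and a per-worker random generator; the mixing rule itself is the usual Beta-distributed convex combination of a batch with a permuted copy of itself. A minimal NumPy sketch of that rule (function and variable names here are illustrative, not the actual `utils.mixup.MixUp` API):

import numpy as np

def mixup_batch(x, y_onehot, alpha, rng):
    """Mix a batch with itself: lambda ~ Beta(alpha, alpha), partner chosen by a random permutation."""
    lam = rng.beta(alpha, alpha)
    perm = rng.permutation(len(x))
    x_mix = lam * x + (1.0 - lam) * x[perm]
    y_mix = lam * y_onehot + (1.0 - lam) * y_onehot[perm]
    return x_mix, y_mix

# usage with the same kind of per-worker generator as above
rng = np.random.RandomState(726)
x = rng.randn(8, 3, 32, 32).astype(np.float32)
y = np.eye(10, dtype=np.float32)[rng.randint(0, 10, size=8)]
x_mix, y_mix = mixup_batch(x, y, alpha=0.2, rng=rng)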
Example #3
    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    expansion=1,
                    stride=1,
                    groups=1,
                    residual_block=None,
                    dropout=None,
                    mixup=False,
                    dp_type=None,
                    dp_percentage=None,
                    dev='cpu'):
        downsample = None
        out_planes = planes * expansion
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          out_planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * expansion),
            )
        if residual_block is not None:
            residual_block = residual_block(out_planes)

        layers = []
        layers.append(
            block(self.inplanes,
                  planes,
                  stride,
                  expansion=expansion,
                  downsample=downsample,
                  groups=groups,
                  residual_block=residual_block,
                  dropout=dropout,
                  dropout_type=dp_type,
                  drop_percentage=dp_percentage,
                  dev=dev))
        self.inplanes = planes * expansion
        for i in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      expansion=expansion,
                      groups=groups,
                      residual_block=residual_block,
                      dropout=dropout,
                      dropout_type=dp_type,
                      drop_percentage=dp_percentage,
                      dev=dev))
        if mixup:
            layers.append(MixUp())
        return nn.Sequential(*layers)
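
The `MixUp()` module appended at the end of a stage implements manifold mixup: it acts as the identity until a coefficient and a permutation are pushed into it, and the `_step` methods further down also use it to mix the targets. Here is a minimal sketch of such a module, assuming a `set_mix`/`mix_target` interface consistent with the calls below (the real implementation may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixUp(nn.Module):
    """Identity unless armed with a coefficient and permutation; then mixes the batch with a permuted copy."""
    def __init__(self):
        super().__init__()
        self.mix_value = None   # lambda in [0, 1]
        self.mix_index = None   # permutation of the batch indices

    def set_mix(self, mix_value, mix_index):
        self.mix_value, self.mix_index = mix_value, mix_index

    def forward(self, x):
        if not self.training or self.mix_value is None:
            return x
        return self.mix_value * x + (1.0 - self.mix_value) * x[self.mix_index]

    def mix_target(self, y, n_class):
        y = F.one_hot(y, n_class).to(dtype=torch.float)
        if self.mix_value is None:
            return y
        return self.mix_value * y + (1.0 - self.mix_value) * y[self.mix_index]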
Example #4
    def _step(self, inputs_batch, target_batch, training=False, chunk_batch=1):
        outputs = []
        total_loss = 0

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for inputs, target in zip(inputs_batch.chunk(chunk_batch, dim=0),
                                  target_batch.chunk(chunk_batch, dim=0)):
            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None:
                    input_mixup = MixUp()
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [m for m in self.model.modules()
                                      if isinstance(m, MixUp)]
                    mixup = _mixup(mixup_modules, self.mixup, inputs.size(0))
                    inputs = input_mixup(inputs)

            # compute output
            output = self.model(inputs)
            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))
            loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())

            if training:
                self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    loss = loss * self.grad_scale
                loss.backward()   # accumulate gradient

            total_loss += float(loss)

        if training:  # post gradient accumulation
            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()  # SGD step
            self.training_steps += 1

        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad
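
Example #4 calls a `_mixup(mixup_modules, alpha, batch_size)` helper that is not shown. Assuming it samples one shared coefficient and permutation and arms every MixUp module with them (using the `set_mix` interface sketched after Example #3), it could look roughly like this:

import torch
import numpy as np

def _mixup(mixup_modules, alpha, batch_size):
    """Sample a shared mixup coefficient and permutation and push them into every MixUp module."""
    mix_value = float(np.random.beta(alpha, alpha))
    mix_index = torch.randperm(batch_size)
    for m in mixup_modules:
        m.set_mix(mix_value, mix_index)
    # return one armed module so the caller can later call mix_target(...)
    return mixup_modules[0]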
Example #5
    def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1, residual_block=None, dropout=None, mixup=False):
        CBN = get_bn_folding_module(nn.Conv2d, nn.BatchNorm2d)
        downsample = None
        out_planes = planes * expansion
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(CBN(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=True))

        if residual_block is not None:
            residual_block = residual_block(out_planes)

        layers = []
        layers.append(block(self.inplanes, planes, stride, expansion=expansion,
                            downsample=downsample, groups=groups, residual_block=residual_block, dropout=dropout))
        self.inplanes = planes * expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups,
                                residual_block=residual_block, dropout=dropout))
        if mixup:
            layers.append(MixUp())
        return nn.Sequential(*layers)
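
Example #5 replaces the usual Conv+BatchNorm pair with a folded module produced by `get_bn_folding_module`. The arithmetic behind that folding is standard; a minimal sketch for a convolution followed by a BatchNorm with frozen running statistics (the helper name and exact API are illustrative):

import torch
import torch.nn as nn

@torch.no_grad()
def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a new Conv2d whose weight/bias absorb the (frozen) BatchNorm that followed `conv`."""
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)          # per-output-channel scale
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused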
Example #6
    def _make_layer(self, block, planes, blocks, expansion=1, stride=1, groups=1, residual_block=None, dropout=None, mixup=False, absorb_bn=False):
        downsample = None
        out_planes = planes * expansion
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=absorb_bn), )
            if not absorb_bn:
                downsample.add_module('1', nn.BatchNorm2d(planes * expansion))
        if residual_block is not None:
            residual_block = residual_block(out_planes)

        layers = []
        layers.append(block(self.inplanes, planes, stride, expansion=expansion,
                            downsample=downsample, groups=groups, residual_block=residual_block, dropout=dropout, absorb_bn=absorb_bn))
        self.inplanes = planes * expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, expansion=expansion, groups=groups,
                                residual_block=residual_block, dropout=dropout, absorb_bn=absorb_bn))
        if mixup:
            layers.append(MixUp())
        return nn.Sequential(*layers)
Example #7
    def _step(self, inputs_batch, target_batch, training=False, average_output=False, chunk_batch=1, scheduled_instructions=None):
        outputs = []
        total_loss = 0
        # self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) else False
        # if scheduled_instructions is None:
        #     self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) else False
        # else:
        #     self.input_grad_statistics = scheduled_instructions['collect_stat']

        meters = {name: {'mean': AverageMeter(), 'std': AverageMeter()} for name in self.module_to_hook.keys()}
        grad_log_stats = {}

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for i, (inputs, target) in enumerate(zip(inputs_batch.chunk(chunk_batch, dim=0),
                                                 target_batch.chunk(chunk_batch, dim=0))):

            self.input_grad_statistics = ((self.epoch > -1) and self.enable_input_grad_statistics
                                          and float(self.epoch) % 2 == 0 and i == 0)
            if training:
                if self.input_grad_statistics:
                    BK_hookF = {}
                    for name, module in self.module_to_hook.items():
                        # print("name: " + str(name))
                        BK_hookF[name] = Hook(module, name, True)

            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None or self.cutmix is not None:
                    input_mixup = CutMix() if self.cutmix else MixUp()
                    mix_val = self.mixup or self.cutmix
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [m for m in self.model.modules()
                                      if isinstance(m, MixUp)]
                    mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                    inputs = input_mixup(inputs)

            # compute output
            output = self.model(inputs)

            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))

            if average_output:
                if isinstance(output, list) or isinstance(output, tuple):
                    output = [_average_duplicates(out, target) if out is not None else None
                              for out in output]
                else:
                    output = _average_duplicates(output, target)
            loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())
            total_loss += float(loss)

            if training:
                if i == 0:
                    self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    loss = loss * self.grad_scale
                if self.loss_scale is not None:
                    pass  # moran
                    # pdb.set_trace()
                    # loss = loss * self.loss_scale
                loss.backward()   # accumulate gradient


                # moran: gather log-norm statistics from the backward hooks
                if self.input_grad_statistics:
                    for hook in BK_hookF.values():
                        meters[hook.name]['mean'].update(float(hook.grad_log_mean), inputs.size(0))
                        meters[hook.name]['std'].update(float(hook.grad_log_std), inputs.size(0))
                    for hook in BK_hookF.values():
                        hook.close()

        grad_log_stats = {name: {'mean': met['mean'].avg, 'std': met['std'].avg} for name, met in meters.items()}
        # print("grad_log_stats! loop 1")
        if training:  # post gradient accumulation
            if self.loss_scale is not None:
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    # p.grad.data.div_(self.loss_scale)  # moran

            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            # dont_update_grad = False
            # for param in self.model.parameters():
            #     param.grad[torch.isnan(param.grad)] = 0
                # if torch.isnan(param.grad).any():  # moran
                #     pdb.set_trace()
                #     dont_update_grad = True
                #     break
            # if dont_update_grad:
            #     pass
            # else:
            #     self.optimizer.step()  # SGD step
            self.optimizer.step()  # SGD step
            self.training_steps += 1

        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad, grad_log_stats
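
`Hook(module, name, True)` is an external helper here; judging by the `grad_log_mean`/`grad_log_std` attributes read above, it records statistics of the log-magnitude of the gradient flowing into a module. A minimal sketch using a full backward hook (Example #8 passes an extra mode string, which this sketch only accepts and ignores; everything beyond the attribute names is an assumption):

import torch

class Hook:
    """Collect mean/std of log2(|grad_input|) for one module via a backward hook."""
    def __init__(self, module, name, backward=True, mode='collect only'):
        # `backward` and `mode` mirror the call sites above; only collection is sketched here
        self.name = name
        self.grad_log_mean = torch.tensor(0.0)
        self.grad_log_std = torch.tensor(0.0)
        self.handle = module.register_full_backward_hook(self._on_backward)

    def _on_backward(self, module, grad_input, grad_output):
        g = grad_input[0]
        if g is None:
            return
        logs = torch.log2(g.detach().abs().clamp_min(1e-38).float())
        self.grad_log_mean, self.grad_log_std = logs.mean(), logs.std()

    def close(self):
        self.handle.remove()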
Example #8
    def _step(self,
              inputs_batch,
              target_batch,
              training=False,
              average_output=False,
              chunk_batch=1,
              scheduled_instructions=None,
              iter=0):
        outputs = []
        total_loss = 0
        grad_log_stats = {}
        # self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) else False
        # if scheduled_instructions is None:
        #     self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) else False
        # else:
        #     self.input_grad_statistics = scheduled_instructions['collect_stat']

        meters = {
            name: {
                'mean': AverageMeter(),
                'std': AverageMeter()
            }
            for name in self.module_to_hook.keys()
        }
        if self.epoch == 0:
            for name, module in self.module_to_hook.items():
                if hasattr(module, 'enable_grad_quantizer'):
                    module.enable_grad_quantizer = False  # this is the default start behavior
                    module.fp_x.fill_(self.fp_bits)

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for i, (inputs, target) in enumerate(
                zip(inputs_batch.chunk(chunk_batch, dim=0),
                    target_batch.chunk(chunk_batch, dim=0))):

            # pdb.set_trace()
            # self.input_grad_statistics = True if ((self.epoch > -1) and self.enable_input_grad_statistics) and (float(self.epoch) % 2 == 0) and i==0 else False
            if training:

                BK_hookF = {}
                if (self.epoch != 0) and (iter in self.iters_in_fp32):
                    for name, module in self.module_to_hook.items():
                        if hasattr(module, 'enable_grad_quantizer'):
                            if iter == self.iters_in_fp32[0]:
                                module.mu.reset()
                            module.enable_grad_quantizer = False  # this is the default start behavior
                            BK_hookF[name] = Hook(module, name, True,
                                                  'collect and quantize')
                        else:
                            BK_hookF[name] = Hook(module, name, True,
                                                  'collect only')
                elif (self.epoch != 0) and (iter in self.iters_wram_up_with_q):
                    for name, module in self.module_to_hook.items():
                        if hasattr(module, 'enable_grad_quantizer'):
                            module.enable_grad_quantizer = True
                            module.loss_scale.fill_(
                                2**(-module.mu.avg - 3))  # update loss scale
                            BK_hookF[name] = Hook(module, name, True,
                                                  'collect and quantize')
                        else:
                            BK_hookF[name] = Hook(module, name, True,
                                                  'collect only')
                    for name, module in self.module_to_hook.items():
                        if hasattr(module, 'enable_grad_quantizer'):
                            if iter == self.iters_wram_up_with_q[0]:
                                module.mu.reset()
                            module.fp8_grad_args = dict(
                                exp_width=self.highest_exp_bits,
                                man_width=int(
                                    (module.fp_x - 1 - self.highest_exp_bits)),
                                exp_bias=(2**(self.highest_exp_bits - 1)) - 1,
                                roundingMode=0,
                                lfsrVal=0)

                elif (self.epoch != 0) and (iter >
                                            self.iters_wram_up_with_q[-1]):
                    for name, module in self.module_to_hook.items():
                        if hasattr(module, 'enable_grad_quantizer'):
                            module.enable_grad_quantizer = True
                            module.loss_scale.fill_(
                                2**(-module.mu.avg - 3))  # update loss scale
                    for name, module in self.module_to_hook.items():
                        if hasattr(module, 'enable_grad_quantizer'):
                            module.fp8_grad_args = dict(
                                exp_width=self.highest_exp_bits,
                                man_width=int(
                                    (module.fp_x - 1 - self.highest_exp_bits)),
                                exp_bias=(2**(self.highest_exp_bits - 1)) - 1,
                                roundingMode=0,
                                lfsrVal=0)
                elif self.epoch == 0:
                    # print("RULE 3")
                    for name, module in self.module_to_hook.items():
                        BK_hookF[name] = Hook(module, name, True,
                                              'collect only')

            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None or self.cutmix is not None:
                    input_mixup = CutMix() if self.cutmix else MixUp()
                    mix_val = self.mixup or self.cutmix
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [
                        m for m in self.model.modules()
                        if isinstance(m, MixUp)
                    ]
                    mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                    inputs = input_mixup(inputs)

            # compute output
            output = self.model(inputs)

            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))

            if average_output:
                if isinstance(output, list) or isinstance(output, tuple):
                    output = [
                        _average_duplicates(out, target)
                        if out is not None else None for out in output
                    ]
                else:
                    output = _average_duplicates(output, target)
            loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())
            total_loss += float(loss)

            if training:
                if i == 0:
                    self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    pass
                    # loss = loss * self.grad_scale
                if self.loss_scale is not None:
                    pass  # moran
                    # pdb.set_trace()
                    # loss = loss * self.loss_scale
                loss.backward()  # accumulate gradient

                ## moran- gather lognorm statistics

                if BK_hookF:
                    # pdb.set_trace()
                    # total = reduce((lambda total, item: total+item.alloc), snds)
                    for hook in BK_hookF.values():
                        meters[hook.name]['mean'].update(
                            float(hook.grad_log_mean), inputs.size(0))
                        meters[hook.name]['std'].update(
                            float(hook.grad_log_std), inputs.size(0))
                    # curr_grad_log_stats = {hook.name: {'mean': hook.grad_log_mean, 'std': hook.grad_log_std} for hook in BK_hookF.values()}
                    # pdb.set_trace()

                    grad_log_stats = {
                        name: {
                            'mean': met['mean'].avg,
                            'std': met['std'].avg
                        }
                        for name, met in meters.items()
                    }
                if BK_hookF:
                    for hook in BK_hookF.values():
                        hook.close()
                ## moran- end

        # print("grad_log_stats! loop 1")
        if training:  # post gradient accumulation
            if self.loss_scale is not None:
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    # p.grad.data.div_(self.loss_scale)  # moran

            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            # dont_update_grad = False
            # for param in self.model.parameters():
            #     param.grad[torch.isnan(param.grad)] = 0
            # if torch.isnan(param.grad).any():  # moran
            #     pdb.set_trace()
            #     dont_update_grad = True
            #     break
            # if dont_update_grad:
            #     pass
            # else:
            #     self.optimizer.step()  # SGD step
            if (self.epoch != 0) and (iter == 0):
                pass
            else:
                self.optimizer.step()  # SGD step
            self.training_steps += 1

        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad, grad_log_stats
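
The warm-up rules in Example #8 derive a per-module loss scale from the running mean of the log2 gradient magnitude (`module.mu`) and rebuild the low-precision argument dict around a shared exponent width. A toy sketch of that selection step; the `-3` margin and the field layout follow the code above, the rest is illustrative:

def choose_grad_format(mean_log2_grad, total_bits=8, exp_width=4):
    """Pick a loss scale and an FP-x field split from the observed gradient magnitude."""
    loss_scale = 2.0 ** (-mean_log2_grad - 3)               # push gradients toward the representable range
    fp_args = dict(exp_width=exp_width,
                   man_width=total_bits - 1 - exp_width,    # 1 sign bit
                   exp_bias=(2 ** (exp_width - 1)) - 1,
                   roundingMode=0,
                   lfsrVal=0)
    return loss_scale, fp_args

scale, args = choose_grad_format(mean_log2_grad=-12.0)      # e.g. gradients around 2**-12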
Example #9
    def _step(self,
              inputs_batch,
              target_batch,
              training=False,
              average_output=False,
              chunk_batch=1):
        outputs = []
        total_loss = 0

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for i, (inputs, target) in enumerate(
                zip(inputs_batch.chunk(chunk_batch, dim=0),
                    target_batch.chunk(chunk_batch, dim=0))):
            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None or self.cutmix is not None:
                    input_mixup = CutMix() if self.cutmix else MixUp()
                    mix_val = self.mixup or self.cutmix
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [
                        m for m in self.model.modules()
                        if isinstance(m, MixUp)
                    ]
                    mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                    inputs = input_mixup(inputs)

                if self.monitor > 3:
                    for mname, m in self.model.named_modules():
                        if isinstance(m, nn.BatchNorm2d):
                            if self.counter % self.monitor_freq == 0:
                                if self.monitor < 5:
                                    self.writer.add_histogram(
                                        f"BN_Weights/{mname}",
                                        m.weight.view(-1), self.counter)
                                    self.writer.add_histogram(
                                        f"BN_Biases/{mname}", m.bias.view(-1),
                                        self.counter)
                            self.writer.add_scalar(f"BN_weight_L2/{mname}",
                                                   (m.weight.data**2).sum(),
                                                   self.counter)
                            self.writer.add_scalar(f"BN_bias_L2/{mname}",
                                                   (m.bias.data**2).sum(),
                                                   self.counter)

                            if self.monitor > 5:
                                self.writer.add_scalar(
                                    f"BN_Running_Var/{mname}",
                                    (m.running_var**2).sum(), self.counter)
                                self.writer.add_scalar(
                                    f"BN_Running_Mean/{mname}",
                                    (m.running_mean**2).sum(), self.counter)

                        elif isinstance(m, AConv2d):
                            self.writer.add_scalar(f"Conv_weight_L2/{mname}",
                                                   (m.weight.data**2).sum(),
                                                   self.counter)
                            ##self.writer.add_scalar(f"Conv_bias_L2/{mname}" ,(m.bias.data**2).sum(), self.counter )

                    self.counter = self.counter + 1

                elif self.monitor > 2:
                    if self.counter % self.monitor_freq == 0:
                        for mname, m in self.model.named_modules():
                            if isinstance(m, nn.BatchNorm2d):
                                if self.monitor < 5:
                                    self.writer.add_histogram(
                                        f"BN_Weights/{mname}",
                                        m.weight.data.view(-1), self.counter)
                                    ##self.writer.add_histogram(f"BN_Biases/{mname}" ,m.bias.data.view(-1) , self.counter )

                    self.counter = self.counter + 1

            # compute output
            output = self.model(inputs)

            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))

            if average_output:
                if isinstance(output, list) or isinstance(output, tuple):
                    output = [
                        _average_duplicates(out, target)
                        if out is not None else None for out in output
                    ]
                else:
                    output = _average_duplicates(output, target)
            loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())
            total_loss += float(loss)

            if training:
                if i == 0:
                    self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    loss = loss * self.grad_scale
                if self.loss_scale is not None:
                    loss = loss * self.loss_scale
                loss.backward()  # accumulate gradient

        if training:  # post gradient accumulation
            if self.loss_scale is not None:
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    p.grad.data.div_(self.loss_scale)

            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()  # SGD step
            self.training_steps += 1

        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad
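
All of these `_step` variants share the same chunked gradient-accumulation skeleton: split the batch, scale each partial loss by 1/chunk_batch, accumulate gradients with backward(), and take a single optimizer step at the end. A stripped-down, self-contained version of that skeleton (plain SGD, no mixup, monitoring or loss scaling):

import torch
import torch.nn as nn

def accumulate_step(model, criterion, optimizer, inputs_batch, target_batch, chunk_batch=2):
    optimizer.zero_grad()
    total_loss = 0.0
    for inputs, target in zip(inputs_batch.chunk(chunk_batch, dim=0),
                              target_batch.chunk(chunk_batch, dim=0)):
        loss = criterion(model(inputs), target) / chunk_batch
        loss.backward()            # gradients accumulate across chunks
        total_loss += float(loss)
    optimizer.step()               # one update for the whole (virtual) batch
    return total_loss

model = nn.Linear(16, 4)
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
accumulate_step(model, nn.CrossEntropyLoss(), torch.optim.SGD(model.parameters(), lr=0.1), x, y)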
Example #10
    def _step(self,
              inputs_batch,
              target_batch,
              training=False,
              average_output=False,
              chunk_batch=1):
        outputs = []
        total_loss = 0

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for i, (inputs, target) in enumerate(
                zip(inputs_batch.chunk(chunk_batch, dim=0),
                    target_batch.chunk(chunk_batch, dim=0))):
            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)
            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None:
                    input_mixup = MixUp()
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [
                        m for m in self.model.modules()
                        if isinstance(m, MixUp)
                    ]
                    mixup = _mixup(mixup_modules, self.mixup, inputs.size(0))
                    inputs = input_mixup(inputs)

            # compute output
            output = self.model(inputs)

            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))

            if average_output:
                if isinstance(output, list) or isinstance(output, tuple):
                    output = [
                        _average_duplicates(out, target)
                        if out is not None else None for out in output
                    ]
                else:
                    output = _average_duplicates(output, target)
            if isinstance(self.criterion, nn.KLDivLoss):
                emb = torch.zeros(output.shape)
                for t in range(target.shape[0]):
                    emb[t] = self.output_embed_fp32[target[t].tolist()]
                loss = self.criterion(F.log_softmax(output, dim=-1),
                                      F.softmax(emb.to(output), dim=-1))
            else:
                loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())
            total_loss += float(loss)

            if training:
                if i == 0:
                    self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    loss = loss * self.grad_scale
                if self.loss_scale is not None:
                    loss = loss * self.loss_scale
                loss.backward()  # accumulate gradient
                if self.update_only_th and not self.optimize_rounding:
                    for p in self.model.parameters():
                        if p.shape[0] == 1000 or p.dim() == 2:
                            p.grad = None
        if training:  # post gradient accumulation
            if self.loss_scale is not None:
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    p.grad.data.div_(self.loss_scale)

            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()  # SGD step
            self.training_steps += 1
            if self.optimize_rounding:
                sd = self.model.state_dict()
                for key in sd:
                    if 'quantize_weight' in key and 'range' in key:
                        trange = sd[key]
                        tzp = sd[key.replace('range', 'zero_point')]
                        weights_name = key.replace(
                            'quantize_weight.running_range', 'weight')
                        #import pdb; pdb.set_trace()
                        t1 = self.fp_state_dict[weights_name.replace(
                            'module.', '')]
                        t2 = sd[weights_name]
                        new_weight = quant_round_constrain(t1, t2, trange, tzp)
                        sd[weights_name] = new_weight
        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad
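
Example #10 swaps the hard-label loss for a KLDivLoss against stored FP32 output embeddings, which is the usual distillation pattern. A minimal sketch of that loss computation; the temperature is omitted as in the code above, and the teacher logits below are random placeholders for `self.output_embed_fp32[target]`:

import torch
import torch.nn as nn
import torch.nn.functional as F

kl = nn.KLDivLoss(reduction='batchmean')
student_logits = torch.randn(8, 1000)
teacher_logits = torch.randn(8, 1000)          # stands in for the stored FP32 output embeddings
loss = kl(F.log_softmax(student_logits, dim=-1),   # KLDivLoss expects log-probabilities as input
          F.softmax(teacher_logits, dim=-1))       # ...and probabilities as target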
Example #11
    def _step(self,
              inputs_batch,
              target_batch,
              training=False,
              average_output=False,
              chunk_batch=1):
        outputs = []
        total_loss = 0

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for i, (inputs, target) in enumerate(
                zip(inputs_batch.chunk(chunk_batch, dim=0),
                    target_batch.chunk(chunk_batch, dim=0))):
            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None or self.cutmix is not None:
                    input_mixup = CutMix() if self.cutmix else MixUp()
                    mix_val = self.mixup or self.cutmix
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [
                        m for m in self.model.modules()
                        if isinstance(m, MixUp)
                    ]
                    mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                    inputs = input_mixup(inputs)

            # compute output
            output = self.model(inputs)

            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))

            if average_output:
                if isinstance(output, list) or isinstance(output, tuple):
                    output = [
                        _average_duplicates(out, target)
                        if out is not None else None for out in output
                    ]
                else:
                    output = _average_duplicates(output, target)
            loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())
            total_loss += float(loss)

            if training:
                if i == 0:
                    self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    loss = loss * self.grad_scale
                if self.loss_scale is not None:
                    loss = loss * self.loss_scale
                loss.backward()  # accumulate gradient

        if training:  # post gradient accumulation
            if self.loss_scale is not None:
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    p.grad.data.div_(self.loss_scale)

            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()  # SGD step
            self.training_steps += 1

        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad
    def _step(self,
              inputs_batch,
              target_batch,
              training=False,
              average_output=False,
              chunk_batch=1):
        outputs = []
        total_loss = 0
        grad_log_stats = []

        if training:
            self.optimizer.zero_grad()
            self.optimizer.update(self.epoch, self.training_steps)

        for i, (inputs, target) in enumerate(
                zip(inputs_batch.chunk(chunk_batch, dim=0),
                    target_batch.chunk(chunk_batch, dim=0))):

            BK_hookF = {}
            if training:
                # only register backward hooks when gradients will actually flow
                for name, module in self.module_to_hook.items():
                    BK_hookF[name] = Hook(module, name, True)

            target = target.to(self.device)
            inputs = inputs.to(self.device, dtype=self.dtype)

            mixup = None
            if training:
                self.optimizer.pre_forward()
                if self.mixup is not None or self.cutmix is not None:
                    input_mixup = CutMix() if self.cutmix else MixUp()
                    mix_val = self.mixup or self.cutmix
                    mixup_modules = [input_mixup]  # input mixup
                    mixup_modules += [
                        m for m in self.model.modules()
                        if isinstance(m, MixUp)
                    ]
                    mixup = _mixup(mixup_modules, mix_val, inputs.size(0))
                    inputs = input_mixup(inputs)

            # compute output
            output = self.model(inputs)

            if mixup is not None:
                target = mixup.mix_target(target, output.size(-1))

            if average_output:
                if isinstance(output, list) or isinstance(output, tuple):
                    output = [
                        _average_duplicates(out, target)
                        if out is not None else None for out in output
                    ]
                else:
                    output = _average_duplicates(output, target)
            loss = self.criterion(output, target)
            grad = None

            if chunk_batch > 1:
                loss = loss / chunk_batch

            if isinstance(output, list) or isinstance(output, tuple):
                output = output[0]

            outputs.append(output.detach())
            total_loss += float(loss)

            if training:
                if i == 0:
                    self.optimizer.pre_backward()
                if self.grad_scale is not None:
                    loss = loss * self.grad_scale
                if self.loss_scale is not None:
                    loss = loss * self.loss_scale
                loss.backward()  # accumulate gradient

                # moran: gather log-norm statistics from the backward hooks
                if self.input_grad_statistics:
                    grad_log_stats = [{
                        'layer_name': hook.name,
                        'mean': hook.grad_log_mean,
                        'std': hook.grad_log_std
                    } for hook in BK_hookF.values()]

                for hook in BK_hookF.values():
                    hook.close()

        if training:  # post gradient accumulation
            if self.loss_scale is not None:
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    p.grad.data.div_(self.loss_scale)

            if self.grad_clip > 0:
                grad = clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()  # SGD step
            self.training_steps += 1

        outputs = torch.cat(outputs, dim=0)
        return outputs, total_loss, grad, grad_log_stats
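
Several of the variants above use manual loss scaling: multiply the loss before backward(), then divide the accumulated gradients by the same factor before clipping and stepping. A minimal sketch of that pattern with a fixed scale (modern code would usually rely on torch.cuda.amp.GradScaler instead):

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scale = 1024.0

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
(loss * loss_scale).backward()                 # scaled backward pass
for p in model.parameters():
    if p.grad is not None:
        p.grad.div_(loss_scale)                # unscale before clipping / stepping
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
optimizer.step()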