예제 #1
0
    def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
        """Create a run manager that trains ``net`` on an XLA device.

        The constructor places the model on ``xm.xla_device()`` through
        ``xmp.MpModelWrapper``, optionally re-initializes its weights, records
        a description of the network and data transforms in
        ``<path>/net_info.txt``, and builds the train/test criteria and the
        optimizer.

        Args:
            path: Output directory for this run; created if missing.
            net: Model to train.
            run_config: Run configuration supplying ``data_provider``,
                ``model_init``, ``mixup_alpha``, ``label_smoothing``,
                ``no_decay_keys`` and ``build_optimizer``.
            init: When true, re-initialize the weights with ``init_models``.
            measure_latency: Passed through to ``get_net_info``.
            no_gpu: Accepted for signature compatibility; not used here.
        """
        self.path = path
        self.net = net
        self.run_config = run_config

        # Training-progress bookkeeping.
        self.start_epoch = 0
        self.best_acc = 0.0

        os.makedirs(self.path, exist_ok=True)

        # Move the model onto the XLA device via the multiprocessing wrapper.
        self.device = xm.xla_device()
        mp_wrapper = xmp.MpModelWrapper(self.net)
        self.net = mp_wrapper.to(self.device)

        if init:
            init_models(self.net, run_config.model_init)

        # Record network statistics, its structure and the data transforms.
        summary = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
        with open('%s/net_info.txt' % self.path, 'w') as info_file:
            info_file.write(json.dumps(summary, indent=4) + '\n')
            # Best effort: skip ``module_str`` if the network cannot provide it.
            # noinspection PyBroadException
            try:
                info_file.write(self.network.module_str + '\n')
            except Exception:
                pass
            train_transform = self.run_config.data_provider.train.dataset.transform
            info_file.write('%s\n' % train_transform)
            test_transform = self.run_config.data_provider.test.dataset.transform
            info_file.write('%s\n' % test_transform)
            info_file.write('%s\n' % self.network)

        # Loss functions: soft-target CE when mixup is configured, label
        # smoothing when requested, plain cross entropy otherwise.
        if isinstance(self.run_config.mixup_alpha, float):
            train_loss = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            def train_loss(pred, target):
                return cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            train_loss = nn.CrossEntropyLoss()
        self.train_criterion = train_loss
        self.test_criterion = nn.CrossEntropyLoss()

        # Optimizer parameter groups.
        decay_keys = self.run_config.no_decay_keys
        if decay_keys:
            keys = decay_keys.split('#')
            net_params = [
                self.net.get_parameters(keys, mode='exclude'),  # with weight decay
                self.net.get_parameters(keys, mode='include'),  # without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                # Fall back to every trainable parameter of the network.
                net_params = [p for p in self.network.parameters() if p.requires_grad]
        self.optimizer = self.run_config.build_optimizer(net_params)
예제 #2
0
    def __init__(self,
                 path,
                 net,
                 run_config,
                 init=True,
                 measure_latency=None,
                 no_gpu=False):
        """Create a single-process run manager for ``net``.

        Moves the model to ``cuda:0`` when CUDA is available and ``no_gpu`` is
        false (otherwise the device is the CPU), optionally re-initializes
        its weights, writes a description of the network and data transforms
        to ``<path>/net_info.txt``, builds the train/test criteria and the
        optimizer, and finally wraps the model in ``torch.nn.DataParallel``.

        Args:
            path: Output directory for this run; created if missing.
            net: Model to train.
            run_config: Run configuration supplying ``data_provider``,
                ``model_init``, ``mixup_alpha``, ``label_smoothing``,
                ``no_decay_keys`` and ``build_optimizer``.
            init: When true, re-initialize the weights with ``init_models``.
            measure_latency: Passed through to ``get_net_info``.
            no_gpu: When true, keep the model on the CPU even if CUDA exists.
        """
        self.path = path
        self.net = net
        self.run_config = run_config

        self.best_acc = 0.0  # float, consistent with the other run managers
        self.start_epoch = 0
        self.grad_dict = OrderedDict()

        os.makedirs(self.path, exist_ok=True)

        # move network to GPU if available (and not explicitly disabled)
        if torch.cuda.is_available() and not no_gpu:
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        # initialize model (default); the network must be passed explicitly,
        # otherwise the init-scheme name would be treated as the model.
        if init:
            init_models(self.net, run_config.model_init)

        def _dataset_transform(dataset):
            # Some datasets are wrappers without their own ``transform`` and
            # expose the wrapped dataset as ``.dataset``; unwrap one level.
            try:
                return dataset.transform
            except AttributeError:
                return dataset.dataset.transform

        # model info
        net_info = get_net_info(self.net,
                                self.run_config.data_provider.data_shape,
                                measure_latency, False)
        with open('%s/net_info.txt' % self.path, 'w') as fout:
            fout.write(json.dumps(net_info, indent=4) + '\n')
            # Best effort: not every network provides ``module_str``.
            # noinspection PyBroadException
            try:
                fout.write(self.network.module_str + '\n')
            except Exception:
                pass
            # Resolve each split independently so a wrapped train set cannot
            # cause the train transform to be written twice.
            fout.write('%s\n' % _dataset_transform(
                self.run_config.data_provider.train.dataset))
            fout.write('%s\n' % _dataset_transform(
                self.run_config.data_provider.test.dataset))
            fout.write('%s\n' % self.network)

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = \
                lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.network.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.network.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                # Fall back to every trainable parameter.
                net_params = [
                    param for param in self.network.parameters()
                    if param.requires_grad
                ]
        self.optimizer = self.run_config.build_optimizer(net_params)

        # NOTE(review): the model is wrapped in DataParallel even when
        # ``no_gpu`` is set; on a CUDA machine DataParallel may then move or
        # expect parameters on a GPU regardless — confirm this is intended.
        self.net = torch.nn.DataParallel(self.net)
예제 #3
0
    def __init__(self,
                 path,
                 net,
                 run_config,
                 hvd_compression,
                 backward_steps=1,
                 is_root=False,
                 init=True):
        """Create a Horovod data-parallel run manager for ``net``.

        Moves the model to CUDA and enables cuDNN benchmarking. On the root
        worker only, it optionally re-initializes the weights and writes a
        description of the network and data transforms to
        ``<path>/net_info.txt``. It then builds the train/test criteria and
        wraps the optimizer in ``hvd.DistributedOptimizer``.

        Args:
            path: Output directory for this run; created if missing.
            net: Model to train.
            run_config: Run configuration supplying ``data_provider``,
                ``model_init``, ``mixup_alpha``, ``label_smoothing``,
                ``no_decay_keys`` and ``build_optimizer``.
            hvd_compression: Gradient compression passed to Horovod.
            backward_steps: Forwarded as ``backward_passes_per_step``.
            is_root: Whether this worker performs init and info logging.
            init: When true (and on root), re-initialize the weights.
        """
        import horovod.torch as hvd

        self.path = path
        self.net = net
        self.run_config = run_config
        self.is_root = is_root

        self.best_acc = 0.0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        self.net.cuda()
        cudnn.benchmark = True
        if init and self.is_root:
            init_models(self.net, self.run_config.model_init)
        if self.is_root:
            # print model info
            net_info = get_net_info(self.net,
                                    self.run_config.data_provider.data_shape)
            with open('%s/net_info.txt' % self.path, 'w') as fout:
                fout.write(json.dumps(net_info, indent=4) + '\n')
                try:
                    fout.write(self.net.module_str + '\n')
                except Exception:
                    # Terminate the line so the transform dumps below start
                    # on their own lines.
                    fout.write('%s do not support `module_str`\n' %
                               type(self.net))
                fout.write(
                    '%s\n' %
                    self.run_config.data_provider.train.dataset.transform)
                fout.write(
                    '%s\n' %
                    self.run_config.data_provider.test.dataset.transform)
                fout.write('%s\n' % self.net)

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = lambda pred, target: \
                cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.net.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                # Fall back to every trainable parameter.
                net_params = [
                    param for param in self.network.parameters()
                    if param.requires_grad
                ]
        self.optimizer = self.run_config.build_optimizer(net_params)
        # Average gradients across workers; ``backward_steps`` lets several
        # local backward passes accumulate before each allreduce.
        self.optimizer = hvd.DistributedOptimizer(
            self.optimizer,
            named_parameters=self.net.named_parameters(),
            compression=hvd_compression,
            backward_passes_per_step=backward_steps,
        )
예제 #4
0
    def __init__(self,
                 main_branch,
                 in_channels,
                 out_channels,
                 expand=1.0,
                 kernel_size=3,
                 act_func='relu',
                 n_groups=2,
                 downsample_ratio=2,
                 upsample_type='bilinear',
                 stride=1):
        """Attach a lightweight residual branch to ``main_branch``.

        The lite residual branch is ``pooling -> conv1 (grouped) -> bn1 ->
        act -> conv2 (1x1) -> final_bn``. The pooling is
        ``MyGlobalAvgPool2d`` when ``downsample_ratio`` is None (in which
        case a 1x1 ``conv1`` kernel is used), otherwise an ``AvgPool2d``
        with kernel and stride equal to ``downsample_ratio``.

        Args:
            main_branch: The wrapped main computation.
            in_channels: Input channels of the residual branch.
            out_channels: Output channels of the residual branch.
            expand: Width multiplier for the hidden channels.
            kernel_size: Kernel size of ``conv1`` (forced to 1 without
                downsampling).
            act_func: Activation name passed to ``build_activation``.
            n_groups: Groups of ``conv1``.
            downsample_ratio: Pooling factor, or None for global pooling.
            upsample_type: Stored in the config; not used in this method.
            stride: Stride of ``conv1``.
        """
        super().__init__()

        self.main_branch = main_branch

        # The constructor arguments as given (before the kernel-size
        # adjustment below).
        self.lite_residual_config = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            expand=expand,
            kernel_size=kernel_size,
            act_func=act_func,
            n_groups=n_groups,
            downsample_ratio=downsample_ratio,
            upsample_type=upsample_type,
            stride=stride,
        )

        if downsample_ratio is None:
            # Global pooling (presumably down to 1x1), so use a 1x1 kernel.
            kernel_size = 1
            pooling = MyGlobalAvgPool2d()
        else:
            pooling = nn.AvgPool2d(downsample_ratio, downsample_ratio, 0)
        padding = get_same_padding(kernel_size)

        mid_channels = make_divisible(int(in_channels * expand),
                                      divisor=MyNetwork.CHANNEL_DIVISIBLE)
        layers = [
            ('pooling', pooling),
            ('conv1', nn.Conv2d(in_channels, mid_channels, kernel_size,
                                stride, padding, groups=n_groups, bias=False)),
            ('bn1', nn.BatchNorm2d(mid_channels)),
            ('act', build_activation(act_func)),
            ('conv2', nn.Conv2d(mid_channels, out_channels, 1, 1, 0,
                                bias=False)),
            ('final_bn', nn.BatchNorm2d(out_channels)),
        ]
        self.lite_residual = nn.Sequential(OrderedDict(layers))

        # Initialize weights, then zero the scale of the last BN layer —
        # presumably so the residual branch starts out contributing ~nothing.
        init_models(self.lite_residual)
        self.lite_residual.final_bn.weight.data.zero_()
예제 #5
0
         args.lite_residual_groups,
     )
     # replace bn layers with gn layers
     replace_bn_with_gn(net, gn_channel_per_group=8)
     # load pretrained model
     init_file = download_url(
         'https://hanlab.mit.edu/projects/tinyml/tinyTL/files/'
         'proxylessnas_mobile+lite_residual@imagenet@ws+gn',
         model_dir='~/.tinytl/')
     net.load_state_dict(
         torch.load(init_file, map_location='cpu')['state_dict'])
     net.classifier = LinearLayer(net.classifier.in_features,
                                  run_config.data_provider.n_classes,
                                  dropout_rate=args.dropout)
     classification_head.append(net.classifier)
     init_models(classification_head)
 else:
     if args.net_path is not None:
         net_config_path = os.path.join(args.net_path, 'net.config')
         init_path = os.path.join(args.net_path, 'init')
     else:
         base_url = 'https://hanlab.mit.edu/projects/tinyml/tinyTL/files/specialized/%s/' % args.dataset
         net_config_path = download_url(
             base_url + 'net.config',
             model_dir='~/.tinytl/specialized/%s' % args.dataset)
         init_path = download_url(base_url + 'init',
                                  model_dir='~/.tinytl/specialized/%s' %
                                  args.dataset)
     net_config = json.load(open(net_config_path, 'r'))
     net = build_network_from_config(net_config)
     net.classifier = LinearLayer(net.classifier.in_features,