Example #1
    def __init__(self,
                 data_dir,
                 batch_size=4,
                 download=False,
                 shuffle=True,
                 validation_split=0.1,
                 num_workers=0,
                 pin_memory=True,
                 flavor=100,
                 training=True):
        if training:
            self.preprocess = CIFAR_TRAIN_PRE
        else:
            self.preprocess = CIFAR_TEST_PRE

        # Only the master rank downloads the data, so concurrent ranks do not race on the files.
        if flavor == 10:
            dataset = datasets.CIFAR10(
                data_dir,
                train=training,
                download=download if is_master() else False,
                transform=self.preprocess)
        else:
            dataset = datasets.CIFAR100(
                data_dir,
                train=training,
                download=download if is_master() else False,
                transform=self.preprocess)

        if training:
            len_set = len(dataset)
            if isinstance(validation_split, int):
                # An integer split is treated as an absolute number of validation samples.
                assert validation_split >= 0, "validation_split cannot be negative"
                assert len_set > validation_split, "validation_split is larger than the dataset size"
                len_valid = validation_split
            else:
                len_valid = int(validation_split * len_set)
            len_train = len_set - len_valid
            train_set, valid_set = random_split(dataset,
                                                [len_train, len_valid])
            # Shard the training set across ranks when running distributed;
            # the validation loader below is built without a distributed sampler.
            self.train_sampler = DistributedSampler(
                train_set) if dist.is_initialized() else None
            self.train_loader = DataLoader(
                train_set,
                batch_size=batch_size,
                shuffle=(shuffle if self.train_sampler is None else False),
                sampler=self.train_sampler,
                num_workers=num_workers,
                pin_memory=pin_memory)
            self.valid_loader = DataLoader(valid_set,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           num_workers=num_workers,
                                           pin_memory=pin_memory)
        else:
            self.test_loader = DataLoader(dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=num_workers,
                                          pin_memory=pin_memory)
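
A minimal usage sketch for the loader above. The wrapper class name (CifarDataLoaderFactory) and the data directory are assumptions about code outside this snippet, which only shows the __init__; the preprocessing constants CIFAR_TRAIN_PRE/CIFAR_TEST_PRE are expected to be defined in the same module.

# Hypothetical single-process usage; names outside the snippet are assumed.
loaders = CifarDataLoaderFactory(
    data_dir='./data',
    batch_size=64,
    download=True,         # only the master rank actually downloads
    validation_split=0.1,  # hold out 10% of the training set for validation
    flavor=10,             # CIFAR-10; any other value falls through to CIFAR-100
    training=True)

for images, labels in loaders.train_loader:
    break  # one batch, just to show the iteration pattern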
Example #2
def restart_if_new_version(original_content):
    new_content = get_own_file_content()
    log('Lengths %s %s' % (len(original_content), len(new_content)))
    log('is master %s ' % utils.is_master())
    # Restart if the script got updated.
    if new_content != original_content:
        log('Restarting tools/internal_test.py, content changed')
        os.execv(sys.argv[0], sys.argv)
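
The helper get_own_file_content is referenced but not shown. A minimal sketch of how it could be implemented, and how the restart check is typically wired in, assuming the script re-reads its own source from sys.argv[0]:

import sys

def get_own_file_content():
    # Re-read this script's own source so it can be compared against the startup snapshot.
    with open(sys.argv[0], 'rb') as f:
        return f.read()

# Illustrative wiring: snapshot the content once at startup, then call
# restart_if_new_version(original) periodically from the main loop.
original = get_own_file_content()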
Example #3
    def __init__(self, model, criterion, metric_ftns, optimizer, config, train_sampler=None):
        self.config = config
        if dist.is_initialized():
            logger_name="{}{}".format(__name__,dist.get_rank())
        else:
            logger_name=__name__
        self.logger = get_logger(name=logger_name, log_dir=config.log_dir, verbosity=config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.train_sampler = train_sampler

        # configuration to monitor model performance and save best
        if not is_master() or self.monitor == 'off':
            self.mnt_mode = 'off'
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        if is_master():
            self.writer = TensorboardWriter(config, cfg_trainer['tensorboard'])
        else:
            self.writer = TensorboardWriter(config, False)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
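
The constructor reads several keys from config['trainer']. A sketch of that section with purely illustrative values (the actual project defines these in its configuration file):

# Illustrative values only; the keys match what __init__ above reads.
trainer_cfg = {
    'epochs': 100,
    'save_period': 5,           # checkpoint every N epochs
    'verbosity': 2,             # passed to get_logger
    'monitor': 'min val_loss',  # '<mode> <metric>', parsed via self.monitor.split()
    'early_stop': 10,           # <= 0 disables early stopping (treated as inf)
    'tensorboard': True,        # enables the TensorboardWriter on the master rank
}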
Example #4
    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch-1)
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1
                    self.logger.info("Monitor metric did\'t improve for epoch#: {}".format(epoch))

                if not_improved_count > self.early_stop:
                    self.logger.info("Monitor metric didn\'t improve for {} epochs Training stops.".format(self.early_stop))
                    if dist.is_initialized():
                        dist.destroy_process_group()
                    exit(0)
            if is_master() and (epoch % self.save_period == 0 or best):
                state = self._generate_model_state(epoch)
                if epoch % self.save_period == 0:
                    self._save_checkpoint(epoch, state)
                if best:
                    self._save_best(state)
            if dist.is_initialized():
                self.logger.debug("Barrier after saving")
                dist.barrier()
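
The set_epoch call at the top of the loop is what reshuffles a DistributedSampler from epoch to epoch. A self-contained sketch of that pattern, independent of this trainer:

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def run_epochs(dataset, step_fn, epochs, batch_size=32):
    # Without set_epoch, every epoch would replay the same per-rank ordering.
    sampler = DistributedSampler(dataset) if dist.is_initialized() else None
    loader = DataLoader(dataset, batch_size=batch_size,
                        sampler=sampler, shuffle=(sampler is None))
    for epoch in range(epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)  # reseed the shuffle for this epoch
        for batch in loader:
            step_fn(batch)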
Example #5
    def __init__(self, VGG_type='A', batch_norm=False, bit_width=8, num_classes=1000, pretrained_model=None):
        super(QuantVGG, self).__init__()
        self.logger = get_logger(name=("{}{}".format(__name__, dist.get_rank()) if dist.is_initialized() else __name__))
        self.inp_quant = qnn.QuantIdentity(bit_width=bit_width, act_quant=INPUT_QUANTIZER, return_quant_tensor=RETURN_QUANT_TENSOR)
        self.features = make_layers(cfgs[VGG_type], batch_norm, bit_width)
        self.avgpool = qnn.QuantAdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            qnn.QuantLinear(512 * 7 * 7, 4096,
                            bias=True,
                            bias_quant=BIAS_QUANTIZER,
                            weight_quant=WEIGHT_QUANTIZER,
                            weight_bit_width=bit_width,
                            weight_scaling_min_val=SCALING_MIN_VAL,
                            return_quant_tensor=RETURN_QUANT_TENSOR),
            qnn.QuantReLU(bit_width=bit_width,
                          act_quant=ACT_QUANTIZER,
                          return_quant_tensor=RETURN_QUANT_TENSOR),
            qnn.QuantDropout(),
            qnn.QuantLinear(4096, 4096,
                            bias=True,
                            bias_quant=BIAS_QUANTIZER,
                            weight_quant=WEIGHT_QUANTIZER,
                            weight_bit_width=bit_width,
                            weight_scaling_min_val=SCALING_MIN_VAL,
                            return_quant_tensor=RETURN_QUANT_TENSOR),
            qnn.QuantReLU(bit_width=bit_width,
                          act_quant=ACT_QUANTIZER,
                          return_quant_tensor=RETURN_QUANT_TENSOR),
            nn.Dropout(),
            qnn.QuantLinear(4096, num_classes,
                            bias=False,
                            weight_quant=WEIGHT_QUANTIZER,
                            weight_scaling_min_val=SCALING_MIN_VAL,
                            weight_bit_width=bit_width,
                            return_quant_tensor=False),
        )
        self.classifier[0].cache_inference_quant_bias = True
        self.classifier[3].cache_inference_quant_bias = True
        self.classifier[6].cache_inference_quant_bias = True

        if is_master():
            print_config(self.logger)

        if pretrained_model is None:
            self._initialize_weights()
        else:
            pre_model = None
            if pretrained_model == 'pytorch':
                self.logger.info(
                    "Initializing with pretrained model from PyTorch")
                # use pytorch's pretrained model
                pre_model = models.vgg16(pretrained=True)
            else:
                pre_model = VGG_net(VGG_type=VGG_type, batch_norm=batch_norm, num_classes=num_classes)
                loaded_model = torch.load(pretrained_model)['state_dict']
                # Check whether the checkpoint was saved from a DataParallel model;
                # keys() returns an 'odict_keys' view, which does not support indexing,
                # so the first key is fetched via next(iter(...)).
                if next(iter(loaded_model.keys())).startswith('module'):
                    # A DataParallel-trained model is wrapped under a 'module.' prefix.
                    pre_model = torch.nn.DataParallel(pre_model)
                    pre_model.load_state_dict(loaded_model)
                    unwrapped_sd = pre_model.module.state_dict()
                    pre_model = VGG_net(VGG_type=VGG_type, batch_norm=batch_norm, num_classes=num_classes)
                    pre_model.load_state_dict(unwrapped_sd)
                else:
                    pre_model.load_state_dict(loaded_model)
            self._initialize_custom_weights(pre_model)
        self.logger.info("Initialization Done")