Example #1
 def __init__(self,
              model,
              criterion,
              metric_ftns,
              optimizer: Optimizer,
              config,
              data_loader,
              valid_data_loader,
              len_epoch=None,
              log_step=2):
     super().__init__(model, criterion, metric_ftns, optimizer, config)
     self.config = config
     self.data_loader = data_loader
     self.valid_data_loader = valid_data_loader
     self.do_validation = self.valid_data_loader is not None
     self.log_step = log_step
     if len_epoch is None:
         self.data_loader_iter = data_loader
         self.len_epoch = len(self.data_loader)
     else:
         self.data_loader_iter = inf_loop(self.data_loader)
         self.valid_loader_iter = inf_loop(self.valid_data_loader)
         self.len_epoch = len_epoch
         self.valid_len_epoch = 53
     self.train_metrics = MetricTracker(
         'train_loss',
         *['train_' + m.__name__ for m in self.metric_ftns],
         writer=self.writer)
     self.valid_metrics = MetricTracker(
         'val_loss',
         *['val_' + m.__name__ for m in self.metric_ftns],
         writer=self.writer)
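
Every example below toggles between epoch-based and iteration-based training by wrapping the loader with inf_loop. That helper is not shown in these excerpts; here is a minimal sketch consistent with how it is used (the pytorch-template convention these trainers follow), offered as an assumption rather than any one repo's exact source:

from itertools import repeat

def inf_loop(data_loader):
    # Re-iterate a finite DataLoader forever; callers break out
    # after len_epoch batches, enabling iteration-based training.
    for loader in repeat(data_loader):
        yield from loader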
Example #2
File: trainer.py Project: shengliu66/ELR
    def __init__(self, model1, model2, model_ema1, model_ema2, train_criterion1, train_criterion2, metrics, optimizer1, optimizer2, config, 
                 data_loader1, data_loader2,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler1=None, lr_scheduler2=None,
                 len_epoch=None, val_criterion=None,
                 model_ema1_copy=None, model_ema2_copy=None):
        super().__init__(model1, model2, model_ema1, model_ema2, train_criterion1, train_criterion2, 
                         metrics, optimizer1, optimizer2, config, val_criterion, model_ema1_copy, model_ema2_copy)
        self.config = config.config
        self.data_loader1 = data_loader1
        self.data_loader2 = data_loader2
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader1)
        else:
            # iteration-based training
            self.data_loader1 = inf_loop(data_loader1)
            self.data_loader2 = inf_loop(data_loader2)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler1 = lr_scheduler1
        self.lr_scheduler2 = lr_scheduler2
        self.log_step = int(np.sqrt(self.data_loader1.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
Example #3
    def __init__(self, model, criterion, metric_ftns, optimizer, config, train_iter, valid_iter, test_iter=None,
                 lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.train_iter, self.valid_iter, self.test_iter = train_iter, valid_iter, test_iter
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.train_iter)
        else:
            # iteration-based training
            self.data_loader = inf_loop(train_iter)
            self.len_epoch = len_epoch

        self.do_validation = self.valid_iter is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(train_iter.batch_size))

        self.train_metrics = MetricTracker('tag_loss','crf_loss','total_loss','p','r','f', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('tag_loss', 'crf_loss','total_loss','p','r','f',*[m.__name__ for m in self.metric_ftns], writer=self.writer)

        # self.cross_entropy_weight_ = [1.0] * schema.class_tag_num[class_id]
        self.cross_entropy_weight_ = [1.0] * 9
        for i in range(1, 9):
            if i % 2 == 1:
                self.cross_entropy_weight_[i] = 1.5
        self.cross_entropy_weight_[0] = 0.1
Example #4
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
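
MetricTracker, used throughout these examples, is also external to the excerpts. A minimal sketch matching the call pattern seen here (a set of metric keys plus an optional TensorBoard-style writer, as in pytorch-template); treat the internals as an assumption, not the exact utility:

import pandas as pd

class MetricTracker:
    # Running totals and averages for a fixed set of metric keys;
    # optionally mirrors every update to a writer via add_scalar.
    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = pd.DataFrame(0.0, index=keys, columns=['total', 'counts', 'average'])

    def reset(self):
        self._data[:] = 0.0

    def update(self, key, value, n=1):
        # accumulate one batch value, optionally weighted by batch size n
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.loc[key, 'total'] += value * n
        self._data.loc[key, 'counts'] += n
        self._data.loc[key, 'average'] = self._data.loc[key, 'total'] / self._data.loc[key, 'counts']

    def avg(self, key):
        return self._data.loc[key, 'average']

    def result(self):
        # {key: running average}, logged at epoch end
        return dict(self._data['average'])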
Example #5
    def __init__(self,
                 model,
                 loss,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, loss, metrics, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.save_train_img_step = self.log_step * 4

        # Binarize NN output
        self.output_threshold = 0.3
Example #6
 def __init__(self,
              model,
              loss,
              metrics,
              optimizer,
              config,
              data_loader,
              valid_data_loader=None,
              lr_scheduler=None,
              len_epoch=None):
     super().__init__(model, loss, metrics, optimizer, config)
     self.config = config
     self.data_loader = data_loader
     if len_epoch is None:
         # epoch-based training
         self.len_epoch = len(self.data_loader)
     else:
         # iteration-based training
         self.data_loader = inf_loop(data_loader)
         self.len_epoch = len_epoch
     self.valid_data_loader = valid_data_loader
     self.do_validation = self.valid_data_loader is not None
     self.lr_scheduler = lr_scheduler
     self.log_step = int(len(data_loader) /
                         5) if len(data_loader) > 5 else 1
Example #7
    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        self.train_metric_ftns = self.metric_ftns[0]
        self.val_metric_ftns = self.metric_ftns[1]
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.val_imgs = self.valid_data_loader.dataset.imgs

        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.train_metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.val_metric_ftns], writer=self.writer)
        
        # training-related config
        cfg_enhance = self.config['trainer_enhance']

        self.mixup = cfg_enhance['mixup']
        if self.mixup:
            self.mixup_alpha = cfg_enhance['mixup_alpha']
Example #8
    def __init__(self,
                 config,
                 logger,
                 generator,
                 discriminator,
                 encoder,
                 data_loader,
                 valid_data_loader=None):
        super().__init__(config, logger, generator, discriminator, encoder,
                         valid_data_loader)
        self.config = config
        self.data_loader = data_loader

        if self.config['trainer'].get('len_epoch', None):
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training
            self.len_epoch = len(self.data_loader)

        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(len(data_loader) /
                            5) if len(data_loader) > 5 else 1
Example #9
	def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
				 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
		super().__init__(model, criterion, metric_ftns, optimizer, config)
		self.config = config
		self.data_loader = data_loader
		if len_epoch is None:
			# epoch-based training
			self.len_epoch = len(self.data_loader)
		else:
			# iteration-based training
			self.data_loader = inf_loop(data_loader)
			self.len_epoch = len_epoch
		self.valid_data_loader = valid_data_loader
		self.do_validation = self.valid_data_loader is not None
		self.lr_scheduler = lr_scheduler
		self.log_step = int(np.sqrt(data_loader.batch_size))

		self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
		self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

		if hasattr(self.data_loader, 'n_valid_samples'):
			validation_samples=self.data_loader.n_valid_samples
		else:
			validation_samples=self.valid_data_loader.n_samples
		self.heatmap_sample_indices=np.sort(np.random.randint(validation_samples, size=min(16, validation_samples)))
Example #10
 def __init__(self, model, optimizer, config, lr_scheduler, num_words, logger, train_dataset, 
              train_dataloader, val_dataset, val_dataloader, len_epoch=None):
     super().__init__(model, optimizer, config, lr_scheduler)
     self.config = config
     self.lr_scheduler = lr_scheduler
     self.logger = logger
     self.train_dataloader = train_dataloader
     self.val_dataloader = val_dataloader
     if len_epoch is None:
       self.len_epoch = len(self.train_dataloader)
     else:
       self.train_dataloader = inf_loop(self.train_dataloader)
       self.len_epoch = len_epoch
     
     self.BCE_loss = torch.nn.BCEWithLogitsLoss()
     self.dec_weight_loss = config["loss"]["dec_weight_loss"]
     self.loc_weight_loss = config["loss"]["loc_weight_loss"]
     self.kws_weight_loss = config["loss"]["kws_weight_loss"]
     self.g2p = config["arch"]["args"]["g2p"]
     self.train_dataset = train_dataset
     self.val_dataset = val_dataset
     self.num_words = num_words
     self.start_BEloc_epoch = config["data_loader"]["args"]["start_BEloc_epoch"]
     self.use_BE_localiser = False
     self.clip = 2.3
     self.do_validation = True
     self.rnn_present = config["arch"]["args"]["rnn2"]
Example #11
    def __init__(self,
                 train_loader,
                 config,
                 valid_loader=None,
                 lr_scheduler=None,
                 max_steps=None):
        self.config = config
        # self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(config['device'])

        self.train_loader = inf_loop(train_loader)
        if max_steps is None:
            self.max_steps = 1000000
        else:
            self.max_steps = max_steps

        self.start_step = 1

        self.log_dir = os.path.join(
            self.config['log_dir'],
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

        if (config['resume'] is None) or (not config['resume']):
            self.checkpoint_dir = os.path.join(
                self.config['checkpoint_dir'],
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
        else:
            self._resume_checkpoint(self.config['resume_path'])

        self.writer = SummaryWriter(log_dir=self.log_dir)
Example #12
    def __init__(
        self,
        model,
        criterion,
        metric_func,
        optimizer,
        num_epochs,
        save_period,
        config,
        data_loaders_dict,
        scheduler=None,
        device=None,
        len_epoch=None,
    ):
        super(Trainer, self).__init__(model, criterion, metric_func, optimizer,
                                      num_epochs, save_period, config, device)

        self.train_data_loader = data_loaders_dict["train"]
        self.val_data_loader = data_loaders_dict["val"]
        if len_epoch is None:
            self._len_epoch = len(self.train_data_loader)
        else:
            self.train_data_loader = inf_loop(self.train_data_loader)
            self._len_epoch = len_epoch

        self._do_validation = self.val_data_loader is not None
        self._scheduler = scheduler
Example #13
    def __init__(self,
                 model,
                 loss,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, loss, metrics, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        # initialize weights for loss function
        self.weights = torch.from_numpy(self.data_loader.dataset.weights)
        self.weights = self.weights.float()
        if torch.cuda.is_available():
            self.weights = self.weights.cuda()
Example #14
 def __init__(self,
              model,
              loss,
              evaluator: SegEvaluator,
              optimizer,
              config,
              data_loader,
              valid_data_loader=None,
              lr_scheduler=None,
              len_epoch=None):
     super().__init__(model, loss, evaluator, optimizer, config)
     self.config = config
     self.dataset = config['data_loader']['args']['dataset']
     # self.ncalss = nclass
     self.data_loader = data_loader
     if len_epoch is None:
         # epoch-based training
         self.len_epoch = len(self.data_loader)
     else:
         # iteration-based training
         self.data_loader = inf_loop(data_loader)
         self.len_epoch = len_epoch
     self.valid_data_loader = valid_data_loader
     self.do_validation = self.valid_data_loader is not None
     self.lr_scheduler = lr_scheduler
     self.log_step = int(np.sqrt(len(data_loader)) * 2)
Example #15
    def __init__(self,
                 model,
                 train_criterion,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 parse,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 val_criterion=None):
        super().__init__(model, train_criterion, metrics, optimizer, config,
                         val_criterion, parse)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
Example #16
    def __init__(self,
                 model,
                 loss,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, loss, metrics, optimizer, config)

        if len_epoch is None:
            # Epoch-based training
            self._data_loader = data_loader
            self._len_epoch = len(self._data_loader)
        else:
            # Iteration-based training
            self._data_loader = inf_loop(data_loader)
            self._len_epoch = len_epoch

        self._valid_data_loader = valid_data_loader
        self._do_validation = self._valid_data_loader is not None

        self._lr_scheduler = lr_scheduler
        self._log_step = int(math.sqrt(self._len_epoch))

        with open(
                Path(config['data_loader']['args']['data_dir']) /
                'sales_mean.pkl', 'rb') as file:
            self._mean = pickle.load(file)
        with open(
                Path(config['data_loader']['args']['data_dir']) /
                'sales_std.pkl', 'rb') as file:
            self._std = pickle.load(file)
Example #17
    def __init__(self,
                 config,
                 logger,
                 generator,
                 discriminator,
                 gif_generator,
                 data_loader,
                 valid_data_loader=None):
        super().__init__(config, logger, generator, discriminator,
                         gif_generator)
        self.config = config
        self.data_loader = data_loader

        if self.config['trainer'].get('len_epoch', None):
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training
            self.len_epoch = len(self.data_loader)

        self.clip_value = discriminator["config"]["clip_value"]

        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(len(data_loader) /
                            5) if len(data_loader) > 5 else 1

        self.z_size = self.generator['config']['arch']['args']['input_size']
        self.output_size = self.generator['config']['arch']['args'][
            'output_size']
Example #18
    def __init__(self,
                 model,
                 loss,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 use_apex=True):

        super().__init__(model, loss, metrics, optimizer, config, use_apex)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.use_apex = use_apex
Example #19
    def __init__(
        self,
        model: Module,
        loss_fn: Callable,
        loss_args: Dict[str, Any],
        metric_fns: List[Callable],
        metric_args: List[Dict[str, Any]],
        optimizer: Optimizer,
        config: ConfigParser,
        data_loader: DataLoader,
        valid_data_loader: Optional[DataLoader] = None,
        lr_scheduler: Optional[Any] = None,
        len_epoch: Optional[int] = None,
    ) -> None:

        super().__init__(model, loss_fn, loss_args, metric_fns, metric_args, optimizer, config)
        self.config: ConfigParser = config  # TODO: isn't this in BaseTrainer already?
        self.data_loader: DataLoader = data_loader

        # Epoch-based training.
        if len_epoch is None:
            self.len_epoch: int = len(self.data_loader)

        # Iteration-based training.
        else:
            self.data_loader: DataLoader = inf_loop(data_loader)
            self.len_epoch: int = len_epoch

        self.valid_data_loader: Optional[DataLoader] = valid_data_loader
        self.do_validation: bool = self.valid_data_loader is not None
        self.lr_scheduler: Optional[Any] = lr_scheduler
        self.log_step: int = int(np.sqrt(data_loader.batch_size))
Example #20
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.n_batches = data_loader.n_samples / data_loader.batch_size

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        if self.do_validation:
            keys_val = ['val_' + k for k in self.keys]
            for key in self.keys + keys_val:
                self.log[key] = []

        cfg_loss = config['trainer']['loss']
        self.alpha = cfg_loss['alpha']
        self.epsilon = cfg_loss['epsilon']
Example #21
    def __init__(self,
                 models,
                 criterion,
                 metric_ftns,
                 optimizers,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_schedulers=None,
                 len_epoch=None):
        super().__init__(models, criterion, metric_ftns, optimizers, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_schedulers = lr_schedulers
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            *['loss_' + str(i) for i in range(self.n_ensembles)],
            *[
                m.__name__ + '_' + str(i) for m in self.metric_ftns
                for i in range(self.n_ensembles)
            ],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            *['loss_' + str(i) for i in range(self.n_ensembles)],
            *[
                m.__name__ + '_' + str(i) for m in self.metric_ftns
                for i in range(self.n_ensembles)
            ],
            writer=self.writer)

        if self.do_validation:
            keys_val = ['val_' + k for k in self.keys]
            for key in self.keys + keys_val:
                self.log[key] = []

        cfg_loss = config['trainer']['loss']
        self.type_in = cfg_loss['type_in']
        self.alpha = cfg_loss['alpha']
        self.loss_type = cfg_loss['loss_type']
        self.censor_R = cfg_loss['censor_R']
        self.soften = cfg_loss['soften']
        self.lambda_in = cfg_loss['lambda_in']
        self.sigma_in = cfg_loss['sigma_in']

        self._create_loss()
Example #22
    def __init__(
        self,
        model,
        criterion,
        metric_func,
        optimizer,
        num_epochs,
        save_period,
        config,
        data_loaders_dict,
        scheduler=None,
        device=None,
        len_epoch=None,
        dataset_name_base="",
        batch_multiplier=1,
        logger=None,
        processed_batch=0,
        adjust_lr_callback=None,
        print_after_batch_num=10,
    ):
        super(Trainer, self).__init__(
            model,
            criterion,
            metric_func,
            optimizer,
            num_epochs,
            save_period,
            config,
            device,
            dataset_name_base,
            batch_multiplier,
            logger,
        )

        self.train_data_loader = data_loaders_dict["train"]
        self.val_data_loader = data_loaders_dict["val"]

        self.num_train_imgs = len(self.train_data_loader.dataset)
        self.num_val_imgs = len(self.val_data_loader.dataset)

        self.processed_batch = processed_batch
        self.adjust_lr_callback = adjust_lr_callback

        if len_epoch is None:
            self._len_epoch = len(self.train_data_loader)
        else:
            self.train_data_loader = inf_loop(self.train_data_loader)
            self._len_epoch = len_epoch

        self._do_validation = self.val_data_loader is not None
        self._scheduler = scheduler

        self._print_after_batch_num = print_after_batch_num

        self._stat_keys = ["loss", "loss_x", "loss_y", "loss_w", "loss_h", "loss_conf", "loss_cls"]
Example #23
    def __init__(self,
                 model,
                 loss_fn_class,
                 loss_fn_domain,
                 metric_ftns,
                 optimizer,
                 config,
                 device,
                 data_loader_source,
                 valid_data_loader_source=None,
                 data_loader_target=None,
                 valid_data_loader_target=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.loss_fn_class = loss_fn_class
        self.loss_fn_domain = loss_fn_domain
        self.data_loader_source = data_loader_source
        self.valid_data_loader_source = valid_data_loader_source
        self.data_loader_target = data_loader_target
        self.valid_data_loader_target = valid_data_loader_target
        self.model.to(self.device)

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = min(len(self.data_loader_source),
                                 len(self.data_loader_target))
        else:
            # FIXME: implement source/target style training or remove this feature
            # iteration-based training
            self.data_loader_source = inf_loop(data_loader_source)
            self.data_loader_target = inf_loop(data_loader_target)
            self.len_epoch = len_epoch
        # FIXME: handle validation round
        self.valid_data_loader = valid_data_loader_source
        self.do_validation = self.valid_data_loader is not None

        self.lr_scheduler = lr_scheduler
        self.log_step = 64

        self.train_metrics = MetricTracker(
            'loss',
            'class_loss',
            'domain_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            'class_loss',
            'domain_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
Example #24
    def __init__(self,
                 s_model,
                 t_model,
                 epoch,
                 criterion,
                 metrics,
                 optimizer,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 checkpoint=None,
                 sts=[]):  # sts=[stop, st_empty, save_dir]
        super().__init__(s_model,
                         criterion,
                         metrics,
                         optimizer,
                         epoch,
                         checkpoint,
                         save_dir=sts[2],
                         st_stop=sts[0])
        self.scaler = GradScaler()
        self.device = device
        self.s_model = self.model
        self.s_model = self.s_model.to(device)
        self.t_model = t_model
        self.t_model = self.t_model.to(device)
        self.kd_criterion = nn.KLDivLoss(reduction='sum')  # size_average=False is deprecated; 'sum' is the equivalent
        self.data_loader = data_loader
        if len_epoch is None:
            self.len_epoch = len(self.data_loader)
        else:
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        self.st_empty = sts[1]
        self.st_container = self.st_empty.beta_container()
        self.lossChart = self.st_container.line_chart()
        self.processBar = self.st_container.progress(0)
        self.epochResult = self.st_container.table()
        self.train_idx = 0

        self.log_step = 100
        self.train_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metrics],
                                           writer=None)
        self.valid_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metrics],
                                           writer=None)
Example #25
    def __init__(self, generator, discriminator, config):
        super().__init__(generator, discriminator, config)
        self.config = config

        if self.config['trainer'].get('len_epoch', None):
            # iteration-based training
            self.generator['data_loader'] = inf_loop(
                self.generator['data_loader'])
            self.discriminator['data_loader'] = inf_loop(
                self.discriminator['data_loader'])
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            self.len_epoch = min(len(self.generator['data_loader']),
                                 len(self.discriminator['data_loader']))

        len_gen_data_loader = len(self.generator['data_loader'])
        self.generator['log_step'] = int(len_gen_data_loader /
                                         5) if len_gen_data_loader > 5 else 1

        len_dis_data_loader = len(self.discriminator['data_loader'])
        self.discriminator['log_step'] = int(
            len_dis_data_loader / 5) if len_dis_data_loader > 5 else 1
Example #26
    def setup_loader(self, train_data_loader, valid_data_loader,
                     test_data_loader):
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader

        if train_data_loader:
            if 'len_epoch' in self.config['trainer']:
                # iteration-based training
                self.train_data_loader = inf_loop(train_data_loader)
                self.len_epoch = self.config['trainer']['len_epoch']
            else:
                self.len_epoch = len(train_data_loader)
Example #27
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config

        self.distill = config._config.get('distill', False)
        
        # add_extra_info returns information about the individual experts, which is
        # crucial for per-expert losses; if it is false, we only get the final mean logits.
        self.add_extra_info = config._config.get('add_extra_info', False)

        if self.distill:
            print("** Distill is on, please double check distill_checkpoint in config **")
            self.teacher_model = config.init_obj('distill_arch', module_arch)
            teacher_checkpoint = torch.load(config['distill_checkpoint'], map_location="cpu")

            self.teacher_model = self.teacher_model.to(self.device)

            teacher_state_dict = teacher_checkpoint["state_dict"]

            rename_parallel_state_dict(teacher_state_dict)
            
            if len(self.device_ids) > 1:
                print("Using multiple GPUs for teacher model")
                self.teacher_model = torch.nn.DataParallel(self.teacher_model, device_ids=self.device_ids)
                load_state_dict(self.teacher_model, {"module." + k: v for k, v in teacher_state_dict.items()}, no_ignore=True)
            else:
                load_state_dict(self.teacher_model, teacher_state_dict, no_ignore=True)

        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        use_fp16 = config._config.get('fp16', False)  # `use_fp16` was undefined in this excerpt; assume it comes from config
        if use_fp16:
            self.logger.warning("FP16 is enabled. Use this option with caution; it is not guaranteed to work correctly.")
            from torch.cuda.amp import GradScaler
            self.scaler = GradScaler()
        else:
            self.scaler = None

        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
Example #28
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 overfit_single_batch=False):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader if not overfit_single_batch else None
        self.test_data_loader = test_data_loader if not overfit_single_batch else None
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.overfit_single_batch = overfit_single_batch

        # -------------------------------------------------
        # add flexibility to allow no metric in config.json
        self.log_loss = ['loss', 'nll', 'kl']
        if self.metric_ftns is None:
            self.train_metrics = MetricTracker(*self.log_loss,
                                               writer=self.writer)
            self.valid_metrics = MetricTracker(*self.log_loss,
                                               writer=self.writer)
        # -------------------------------------------------
        else:
            self.train_metrics = MetricTracker(
                *self.log_loss,
                *[m.__name__ for m in self.metric_ftns],
                writer=self.writer)
            self.valid_metrics = MetricTracker(
                *self.log_loss,
                *[m.__name__ for m in self.metric_ftns],
                writer=self.writer)
            self.test_metrics = MetricTracker(
                *[m.__name__ for m in self.metric_ftns], writer=self.writer)
Example #29
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, config)

        ##############
        # The graph attention must be initialized here; doing it later, inside the
        # model, would make the encapsulation more complicated.
        # When passing these matrices in, some weights are not on the GPU yet,
        # so an extra .to() call is needed.
        self.model.g_att_v.init_params(data_loader.dataset.edge_matrix_v,
                                       data_loader.dataset.affectiveness_v,
                                       data_loader.dataset.embedding_concept_v,
                                       self.model.device)
        self.model.g_att_a.init_params(data_loader.dataset.edge_matrix_a,
                                       data_loader.dataset.affectiveness_a,
                                       data_loader.dataset.embedding_concept_a,
                                       self.model.device)
        self.model.g_att_t.init_params(data_loader.dataset.edge_matrix_t,
                                       data_loader.dataset.affectiveness_t,
                                       data_loader.dataset.embedding_concept_t,
                                       self.model.device)
        ##############

        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = config.init_obj('lr_scheduler',
                                            torch.optim.lr_scheduler,
                                            self.optimizer)
        self.log_step = 200
        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
Example #30
    def __init__(self, model: DepthwiseStudent, criterions, metric_ftns, optimizer, config, train_data_loader,
                 valid_data_loader=None, lr_scheduler=None, weight_scheduler=None):
        super().__init__(model, None, metric_ftns, optimizer, config)
        self.config = config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_validation_interval = self.config['trainer']['do_validation_interval']
        self.lr_scheduler = lr_scheduler
        self.weight_scheduler = weight_scheduler
        self.log_step = config['trainer']['log_step']
        self.importance_log_interval = self.config['trainer']['importance_log_interval']
        if "len_epoch" in self.config['trainer']:
            # iteration-based training
            self.train_data_loader = inf_loop(train_data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training
            self.len_epoch = len(self.train_data_loader)

        # Metrics
        # Train
        self.train_metrics = MetricTracker('loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
                                           *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.train_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        self.train_teacher_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Valid
        self.valid_metrics = MetricTracker('loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
                                           *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Test
        self.test_metrics = MetricTracker('loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
                                          *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.test_iou_metrics = CityscapesMetricTracker(writer=self.writer)

        # Tracker for early stop if val miou doesn't increase
        self.val_iou_tracker = EarlyStopTracker('best', 'max', 0.01, 'rel')
        # Tracker for importance of filter in layer
        self.importance_tracker = ImportanceFilterTracker(writer=self.writer)

        # Only the list of criterions is used; the unused `criterion` property is removed below
        self.criterions = nn.ModuleList(criterions).to(self.device)
        if isinstance(self.model, nn.DataParallel):
            self.criterions = nn.DataParallel(self.criterions)
        del self.criterion

        # Resume checkpoint if path is available in config
        if 'resume_path' in self.config['trainer']:
            self.resume(self.config['trainer']['resume_path'])
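
Taken together, every constructor above prepares the same contract for its training loop: iterate self.data_loader and stop after self.len_epoch batches, which is what makes inf_loop-wrapped loaders terminate each epoch. A hypothetical sketch of such a loop, reusing the attribute names from the excerpts (this is an illustration of the pattern, not any single repo's actual method):

    def _train_epoch(self, epoch):
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            self.train_metrics.update('loss', loss.item())
            if batch_idx % self.log_step == 0:
                self.logger.debug('Epoch {} [{}/{}] loss: {:.6f}'.format(
                    epoch, batch_idx, self.len_epoch, loss.item()))
            if batch_idx == self.len_epoch:
                # with an inf_loop-wrapped loader this break is the only exit
                break
        return self.train_metrics.result()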