Code example #1
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config,
                         device, data_loader, valid_data_loader, lr_scheduler)

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        # Define evaluator
        self.evaluator = MNISTTester(self.model, self.criterion,
                                     self.metric_ftns, self.config,
                                     self.device, self.valid_data_loader, True)
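
Every constructor in this collection switches to iteration-based training by wrapping the loader in inf_loop, which none of the snippets define. Below is a minimal sketch of that helper, assuming the common pytorch-template-style implementation: an endless generator that restarts the loader whenever it is exhausted.

from itertools import repeat

def inf_loop(data_loader):
    """Endlessly iterate over a data loader: restart it whenever exhausted."""
    for loader in repeat(data_loader):
        yield from loader

With such a wrapper, the training loop simply counts batches up to len_epoch instead of relying on the loader's own length.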
Code example #2
    def __init__(self,
                 model,
                 train_criterion,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 parse,
                 valid_data_loader=None,
                 test_data_loader=None,
                 teacher=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 val_criterion=None,
                 mode=None,
                 entropy=False,
                 threshold=0.1):
        super().__init__(model, train_criterion, metrics, optimizer, config,
                         val_criterion, parse)
        self.config = config
        self.data_loader = data_loader
        self.mode = mode
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader

        if teacher is not None:
            self.teacher = teacher.to(self.device)
        else:
            self.teacher = teacher

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
        # Fraction of (noisy) training labels that match the ground-truth labels.
        self.purity = (data_loader.train_dataset.train_labels ==
                       data_loader.train_dataset.train_labels_gt).sum() / len(data_loader.train_dataset)

        # Visdom visualization
        self.entropy = entropy
        if self.entropy:
            self.entro_loss = Entropy(threshold)
Code example #3
    def __init__(self,
                 model,
                 train_criterion,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 parse,
                 valid_data_loader=None,
                 test_data_loader=None,
                 teacher=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 val_criterion=None,
                 mode=None,
                 entropy=False,
                 threshold=0.1):
        super().__init__(model, train_criterion, metrics, optimizer, config,
                         val_criterion, parse)
        self.config = config
        self.data_loader = data_loader
        self.mode = mode
        self.parse = parse

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.dynamic_train_data_loader = copy.deepcopy(data_loader)
        self.valid_data_loader = valid_data_loader

        self.warm_up = parse.warmup
        self.every = parse.every

        # Rebuild the original (unshuffled) training loader from the config.
        self.orig_data_loader = getattr(module_data, self.config['data_loader']['type'])(
            self.config['data_loader']['args']['data_dir'],
            batch_size=self.config['data_loader']['args']['batch_size'],
            shuffle=False,
            validation_split=0.1,
            num_batches=self.config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=self.config['data_loader']['args']['num_workers'],
            pin_memory=self.config['data_loader']['args']['pin_memory'])

        if teacher is not None:
            self.teacher = teacher.to(self.device)
        else:
            self.teacher = teacher

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
        # Fraction of (noisy) training labels that match the ground-truth labels.
        self.purity = (data_loader.train_dataset.train_labels ==
                       data_loader.train_dataset.train_labels_gt).sum() / len(data_loader.train_dataset)
        self.teacher_idx = None

        # Visdom visualization
        self.entropy = entropy
        if self.entropy:
            self.entro_loss = Entropy(threshold)
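
Example #3 above (and example #7 below) instantiate loaders and optimizers by reflection: look up a class name from the config and call it with the configured kwargs, as in getattr(module_data, self.config['data_loader']['type'])(...) and config.initialize('optimizer', torch.optim, ...). Here is a minimal sketch of that idiom as a free function; the name initialize and the config layout ({'type': ..., 'args': {...}}) are assumptions following the pytorch-template convention:

def initialize(config, name, module, *args):
    """Instantiate config[name]['type'] from the given module with the
    configured kwargs, e.g. initialize(cfg, 'optimizer', torch.optim,
    model.parameters())."""
    cls_name = config[name]['type']        # e.g. 'Adam'
    cls_args = dict(config[name]['args'])  # e.g. {'lr': 0.001, 'weight_decay': 0}
    return getattr(module, cls_name)(*args, **cls_args)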
Code example #4
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.best_valid = 0

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        self.file_metrics = str(self.checkpoint_dir / 'metrics.csv')
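
Examples #1, #4, and #5 record per-batch metrics through a MetricTracker whose definition is not shown. The sketch below assumes the interface these snippets rely on: keyed running averages, optionally mirrored to a TensorBoard-style writer. (The pytorch-template original backs this with a pandas DataFrame; a plain dict behaves the same.)

class MetricTracker:
    """Keep running totals and averages for a fixed set of metric keys."""

    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = {key: {'total': 0.0, 'count': 0} for key in keys}

    def reset(self):
        for entry in self._data.values():
            entry['total'], entry['count'] = 0.0, 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)  # assumes a SummaryWriter-like API
        self._data[key]['total'] += value * n
        self._data[key]['count'] += n

    def avg(self, key):
        entry = self._data[key]
        return entry['total'] / max(entry['count'], 1)

    def result(self):
        return {key: self.avg(key) for key in self._data}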
Code example #5
    def __init__(self,
                 model,
                 criterion,
                 train_metric_ftns,
                 eval_metric_ftns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 validate_only=False):
        """

        :param model: The model to train.
        :param criterion: this value is ignored and overwritten internally.
        :param train_metric_ftns: The metric function names to use for training.
        :param eval_metric_ftns: The metric function names to use for evaluation.
        :param optimizer: The optimizer to use.
        :param config: The configuration file for the run.
        :param device: The device to train on.
        :param data_loader: The training data loader to use.
        :param valid_data_loader: The validation data loader to use.
        :param lr_scheduler: scheduler for the learning rate.
        :param len_epoch: The number of batches in an epoch (for iteration-based training).
        :param validate_only: if resuming, only run validation on the resumed checkpoint and exit.
        """
        self.vocab = model.vocab
        self.pad_idx = self.vocab['<pad>']

        self.criterion = criterion
        super().__init__(model, self.criterion, train_metric_ftns,
                         eval_metric_ftns, optimizer, config, device,
                         data_loader, valid_data_loader, lr_scheduler)

        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.train_metric_ftns],
            writer=self.writer)

        # Define evaluator.
        self.evaluator = Seq2SeqSimpleTester(self.model, self.criterion,
                                             self.eval_metric_ftns,
                                             self.config, self.device,
                                             self.valid_data_loader, True)

        # Run validation and exit.
        if validate_only:
            val_log = self.evaluator.test()
            log = {'val_' + k: round(v, 5) for k, v in val_log.items()}
            print(log)
            exit()
Code example #6
    def __init__(self,
                 model,
                 train_criterion,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 parse,
                 valid_data_loader=None,
                 test_data_loader=None,
                 teacher=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 val_criterion=None,
                 mode=None,
                 entropy=False,
                 threshold=0.1,
                 epoch_decay_start=80,
                 n_epoch=200,
                 learning_rate=0.001):
        super().__init__(model, train_criterion, metrics, optimizer, config,
                         val_criterion, parse)
        self.config = config
        self.data_loader = data_loader
        self.mode = mode
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader

        # Specific attributes for co-teaching: two models, two optimizers.
        self.model_1, self.model_2 = model[0].to(self.device), model[1].to(self.device)
        self.optimizer_1, self.optimizer_2 = optimizer[0], optimizer[1]

        if lr_scheduler is not None:
            self.lr_scheduler_1, self.lr_scheduler_2 = lr_scheduler[0], lr_scheduler[1]
        else:
            self.lr_scheduler_1, self.lr_scheduler_2 = None, None

        # Re-initialize the conv layers of both models with Kaiming normal init.
        for m in self.model_1.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')

        for m in self.model_2.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        # Note: no single self.lr_scheduler here; each model has its own scheduler above.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
        # Visdom visualization

        self.entropy = entropy
        if self.entropy:
            self.entro_loss = Entropy(threshold)

        # Learning-rate and beta1 decay plan for the Adam optimizer.

        self.epoch_decay_start = epoch_decay_start
        self.n_epoch = n_epoch
        self.learning_rate = learning_rate

        mom1, mom2 = 0.9, 0.1
        self.alpha_plan = [self.learning_rate] * self.n_epoch
        self.beta1_plan = [mom1] * self.n_epoch

        # Linearly decay the learning rate to zero (and drop beta1 to mom2)
        # from epoch_decay_start through the final epoch.
        for i in range(self.epoch_decay_start, self.n_epoch):
            self.alpha_plan[i] = float(self.n_epoch - i) / (
                self.n_epoch - self.epoch_decay_start) * self.learning_rate
            self.beta1_plan[i] = mom2
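
The alpha_plan and beta1_plan lists above only precompute the schedule; a per-epoch hook still has to write the values into the optimizer. Below is a sketch of how co-teaching-style trainers typically apply it; the method name _adjust_learning_rate is an assumption, while mutating param_groups is standard PyTorch.

    def _adjust_learning_rate(self, optimizer, epoch):
        """Apply the precomputed schedule to an Adam optimizer in-place."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.alpha_plan[epoch]
            # Adam takes betas as a (beta1, beta2) tuple; only beta1 is decayed.
            param_group['betas'] = (self.beta1_plan[epoch], 0.999)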
Code example #7
    def __init__(self,
                 model,
                 train_criterion,
                 metrics,
                 optimizer,
                 config,
                 data_loader,
                 parse,
                 valid_data_loader=None,
                 test_data_loader=None,
                 teacher=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 val_criterion=None,
                 mode=None,
                 entropy=False,
                 threshold=0.1):
        super().__init__(model, train_criterion, metrics, optimizer, config,
                         val_criterion, parse)
        self.config = config
        self.data_loader = data_loader
        self.mode = mode
        self.parse = parse

        ######################################
        # Specific attributes for coteaching #
        ######################################

        # Two independent copies of the model, each with its own optimizer
        # and learning-rate scheduler.
        self.model_1 = copy.deepcopy(model).to(self.device)
        self.model_2 = copy.deepcopy(model).to(self.device)
        trainable_params1 = filter(lambda p: p.requires_grad,
                                   self.model_1.parameters())
        trainable_params2 = filter(lambda p: p.requires_grad,
                                   self.model_2.parameters())
        self.optimizer_1 = config.initialize('optimizer', torch.optim,
                                             [{'params': trainable_params1}])
        self.optimizer_2 = config.initialize('optimizer', torch.optim,
                                             [{'params': trainable_params2}])
        self.lr_scheduler_1 = config.initialize('lr_scheduler',
                                                torch.optim.lr_scheduler,
                                                self.optimizer_1)
        self.lr_scheduler_2 = config.initialize('lr_scheduler',
                                                torch.optim.lr_scheduler,
                                                self.optimizer_2)

        # Redundant aliases kept so the base class can run; the per-model
        # optimizers and schedulers above are the ones actually used.
        self.optimizer = self.optimizer_1
        self.lr_scheduler = self.lr_scheduler_1

        # Re-initialize the conv layers of both models with Kaiming normal init.
        for m in self.model_1.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        for m in self.model_2.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        # End of co-teaching-specific setup.

        self.warm_up = parse.warmup
        self.every = parse.every

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.len_epoch_1 = self.len_epoch_2 = self.len_epoch
        self.dynamic_train_data_loader_1 = copy.deepcopy(data_loader)
        self.dynamic_train_data_loader_2 = copy.deepcopy(data_loader)
        self.valid_data_loader = valid_data_loader

        # Rebuild the original (unshuffled) training loader from the config.
        self.orig_data_loader = getattr(module_data, self.config['data_loader']['type'])(
            self.config['data_loader']['args']['data_dir'],
            batch_size=self.config['data_loader']['args']['batch_size'],
            shuffle=False,
            validation_split=0.1,
            num_batches=self.config['data_loader']['args']['num_batches'],
            training=True,
            num_workers=self.config['data_loader']['args']['num_workers'],
            pin_memory=self.config['data_loader']['args']['pin_memory'])

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []
        # Fraction of (noisy) training labels that match the ground-truth labels.
        self.purity_1 = self.purity_2 = (
            data_loader.train_dataset.train_labels ==
            data_loader.train_dataset.train_labels_gt).sum() / len(data_loader.train_dataset)
        self.teacher_idx_1, self.teacher_idx_2 = None, None

        # Visdom visualization
        self.entropy = entropy
        if self.entropy:
            self.entro_loss = Entropy(threshold)
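
Several examples guard an Entropy(threshold) loss behind the entropy flag without defining the class. The sketch below is purely an assumption about its shape: an entropy regularizer over softmax outputs that penalizes only predictions whose entropy exceeds the threshold. The real class in the source repositories may differ.

import torch.nn as nn
import torch.nn.functional as F

class Entropy(nn.Module):
    """Hypothetical entropy regularizer: mean per-sample prediction entropy,
    restricted to samples whose entropy exceeds the threshold."""

    def __init__(self, threshold=0.1):
        super().__init__()
        self.threshold = threshold

    def forward(self, logits):
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)
        entropy = -(probs * log_probs).sum(dim=1)  # per-sample entropy
        mask = entropy > self.threshold            # skip already-confident samples
        if mask.any():
            return entropy[mask].mean()
        return logits.new_zeros(())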