Example #1
0
class FoodIngredients(Network):
    """Two-headed transfer-learning model for food images.

    Head 1 predicts the food class (single-label, trained with
    ``criterion1``, cross-entropy by default); head 2 predicts the
    ingredients (multi-label, trained with ``criterion2``,
    BCE-with-logits by default). Both heads share a frozen pretrained
    convolutional backbone (DenseNet-121, ResNet-34 or ResNet-50),
    optionally extended with extra conv layers (``add_extra``).
    """

    def __init__(self,
                 model_name='DenseNet',
                 model_type='food',
                 lr=0.02,
                 optimizer_name='Adam',
                 criterion1=None,
                 criterion2=None,
                 dropout_p=0.45,
                 pretrained=True,
                 device=None,
                 best_accuracy=0.,
                 best_validation_loss=None,
                 best_model_file='best_model.pth',
                 head1=None,
                 head2=None,
                 class_names=None,
                 num_classes=None,
                 ingredient_names=None,
                 num_ingredients=None,
                 add_extra=True,
                 set_params=True,
                 set_head=True):
        # Mutable (dict/list) and stateful (nn.Module) defaults are built
        # per call instead of once at definition time, so instances never
        # share head configs, name lists or loss modules.
        if criterion1 is None:
            criterion1 = nn.CrossEntropyLoss()
        if criterion2 is None:
            criterion2 = nn.BCEWithLogitsLoss()
        if head1 is None:
            head1 = {
                'num_outputs': 10,
                'layers': [],
                'model_type': 'classifier'
            }
        if head2 is None:
            head2 = {
                'num_outputs': 10,
                'layers': [],
                'model_type': 'multi_label_classifier'
            }
        if class_names is None:
            class_names = []
        if ingredient_names is None:
            ingredient_names = []

        super().__init__(device=device)

        self.set_transfer_model(model_name,
                                pretrained=pretrained,
                                add_extra=add_extra,
                                dropout_p=dropout_p)

        if set_head:
            self.set_model_head(model_name=model_name,
                                head1=head1,
                                head2=head2,
                                dropout_p=dropout_p,
                                criterion1=criterion1,
                                criterion2=criterion2,
                                device=device)
        if set_params:
            self.set_model_params(
                # BUG FIX: the user-supplied criteria were previously not
                # forwarded here, so set_model_params silently fell back to
                # its own default losses.
                criterion1=criterion1,
                criterion2=criterion2,
                optimizer_name=optimizer_name,
                lr=lr,
                dropout_p=dropout_p,
                model_name=model_name,
                model_type=model_type,
                best_accuracy=best_accuracy,
                best_validation_loss=best_validation_loss,
                best_model_file=best_model_file,
                class_names=class_names,
                num_classes=num_classes,
                ingredient_names=ingredient_names,
                num_ingredients=num_ingredients,
            )

        self.model = self.model.to(device)

    def set_model_params(self,
                         criterion1=None,
                         criterion2=None,
                         optimizer_name='Adam',
                         lr=0.1,
                         dropout_p=0.45,
                         model_name='DenseNet',
                         model_type='cv_transfer',
                         best_accuracy=0.,
                         best_validation_loss=None,
                         best_model_file='best_model_file.pth',
                         head1=None,
                         head2=None,
                         class_names=None,
                         num_classes=None,
                         ingredient_names=None,
                         num_ingredients=None):
        """Store training hyper-parameters and label metadata.

        ``head1``/``head2`` are accepted for interface compatibility but
        are not used here — head configuration belongs to
        :meth:`set_model_head`.
        """
        if criterion1 is None:
            criterion1 = nn.CrossEntropyLoss()
        if criterion2 is None:
            criterion2 = nn.BCEWithLogitsLoss()
        if class_names is None:
            class_names = []
        if ingredient_names is None:
            ingredient_names = []

        print(
            'Food Names: current best accuracy = {:.3f}'.format(best_accuracy))
        if best_validation_loss is not None:
            print('Food Ingredients: current best loss = {:.3f}'.format(
                best_validation_loss))

        super(FoodIngredients,
              self).set_model_params(optimizer_name=optimizer_name,
                                     lr=lr,
                                     dropout_p=dropout_p,
                                     model_name=model_name,
                                     model_type=model_type,
                                     best_accuracy=best_accuracy,
                                     best_validation_loss=best_validation_loss,
                                     best_model_file=best_model_file)
        self.class_names = class_names
        self.num_classes = num_classes
        # BUG FIX: the attribute was stored only under the misspelled name
        # 'ingredeint_names', which made get_model_params (reading
        # 'ingredient_names') raise AttributeError.
        self.ingredient_names = ingredient_names
        # Deprecated alias kept so external code using the old misspelled
        # attribute keeps working.
        self.ingredeint_names = ingredient_names
        self.num_ingredients = num_ingredients
        self.criterion1 = criterion1
        self.criterion2 = criterion2

    def forward(self, x):
        """Run the shared backbone, then both heads.

        Returns:
            tuple: ``(food_logits, ingredient_logits)`` — the last two
            children of ``self.model`` are the two FC heads.
        """
        children = list(self.model.children())
        for module in children[:-2]:
            x = module(x)
        food = children[-2](x)
        ingredients = children[-1](x)
        return (food, ingredients)

    def compute_loss(self, outputs, labels, w1=1., w2=1.):
        """Return the weighted sum of the two head losses as a 1-list
        (list form matches the base-class loss interface)."""
        out1, out2 = outputs
        label1, label2 = labels
        loss1 = self.criterion1(out1, label1)
        loss2 = self.criterion2(out2, label2)
        return [(loss1 * w1) + (loss2 * w2)]

    def freeze(self, train_classifier=True):
        """Freeze the backbone; optionally keep the two heads trainable."""
        super(FoodIngredients, self).freeze()
        if train_classifier:
            for param in self.model.fc1.parameters():
                param.requires_grad = True
            for param in self.model.fc2.parameters():
                param.requires_grad = True

    def parallelize(self):
        """Wrap model and criterion for multi-GPU data parallelism.

        NOTE(review): this reads ``self.criterion``, but this class only
        sets ``criterion1``/``criterion2`` — unless the base class defines
        ``self.criterion``, this raises AttributeError. Confirm against
        ``Network``.
        """
        self.parallel = True
        self.model = DataParallelModel(self.model)
        self.criterion = DataParallelCriterion(self.criterion)

    def set_transfer_model(self,
                           mname,
                           pretrained=True,
                           add_extra=True,
                           dropout_p=0.45):
        """Load the pretrained backbone named ``mname``, freeze all its
        parameters, and optionally build extra conv layers
        (``self.dream_model``) appended later by :meth:`set_model_head`.

        Raises:
            KeyError: if ``mname`` is not a supported model name
                (same exception type as the original dict lookup).
        """
        self.model = None
        # Factories instead of instances: previously all three pretrained
        # networks were constructed (and potentially downloaded) on every
        # call even though only one was used.
        models_dict = {
            'densenet': {
                'model': lambda: models.densenet121(pretrained=pretrained),
                'conv_channels': 1024
            },
            'resnet34': {
                'model': lambda: models.resnet34(pretrained=pretrained),
                'conv_channels': 512
            },
            'resnet50': {
                'model': lambda: models.resnet50(pretrained=pretrained),
                'conv_channels': 2048
            }
        }
        key = mname.lower()
        if key not in models_dict:
            # BUG FIX: the unsupported-name message was unreachable before —
            # the KeyError fired on the dict lookup outside the try block.
            print(
                'Setting transfer learning model: model name {} not supported'.
                format(mname))
            raise KeyError(mname)
        meta = models_dict[key]
        model = meta['model']()
        for param in model.parameters():
            param.requires_grad = False  # freeze the whole backbone
        self.model = model
        print(
            'Setting transfer learning model: self.model set to {}'.format(
                mname))

        # creating and adding extra layers to the model
        dream_model = None
        if add_extra:
            channels = meta['conv_channels']
            dream_model = nn.Sequential(
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
                nn.Dropout2d(dropout_p),
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
                nn.Dropout2d(dropout_p),
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
                nn.Dropout2d(dropout_p))
        self.dream_model = dream_model

    def set_model_head(
            self,
            model_name='DenseNet',
            head1=None,
            head2=None,
            criterion1=None,
            criterion2=None,
            adaptive=True,
            dropout_p=0.45,
            device=None):
        """Replace the backbone's original classifier with pooling plus two
        FC heads (``fc1`` for food, ``fc2`` for ingredients)."""
        if head1 is None:
            head1 = {
                'num_outputs': 10,
                'layers': [],
                'class_names': None,
                'model_type': 'classifier'
            }
        if head2 is None:
            # Default model_type fixed from the misspelled
            # 'muilti_label_classifier'.
            head2 = {
                'num_outputs': 10,
                'layers': [],
                'class_names': None,
                'model_type': 'multi_label_classifier'
            }
        if criterion1 is None:
            criterion1 = nn.CrossEntropyLoss()
        if criterion2 is None:
            criterion2 = nn.BCEWithLogitsLoss()

        models_meta = {
            'resnet34': {
                'conv_channels': 512,
                'head_id': -2,
                'adaptive_head': [DAI_AvgPool],
                'normal_head': [nn.AvgPool2d(7, 1)]
            },
            'resnet50': {
                'conv_channels': 2048,
                'head_id': -2,
                'adaptive_head': [DAI_AvgPool],
                'normal_head': [nn.AvgPool2d(7, 1)]
            },
            'densenet': {
                'conv_channels': 1024,
                'head_id': -1,
                'adaptive_head': [nn.ReLU(inplace=True), DAI_AvgPool],
                'normal_head': [nn.ReLU(inplace=True),
                                nn.AvgPool2d(7, 1)]
            }
        }

        name = model_name.lower()
        meta = models_meta[name]
        modules = list(self.model.children())
        # Everything up to the original classifier, plus the extra conv
        # layers built in set_transfer_model (if any).
        trunk = modules[:meta['head_id']]
        if self.dream_model:
            trunk += self.dream_model
        heads = [head1, head2]
        crits = [criterion1, criterion2]
        fcs = []
        for head, criterion in zip(heads, crits):
            head['criterion'] = criterion
            if head['model_type'].lower() == 'classifier':
                head['output_non_linearity'] = None
            fc = modules[-1]
            # The replaced classifier is either a plain Linear
            # (fc.in_features) or an already-wrapped FC head.
            try:
                in_features = fc.in_features
            except AttributeError:
                in_features = fc.model.out.in_features
            fc = FC(num_inputs=in_features,
                    num_outputs=head['num_outputs'],
                    layers=head['layers'],
                    model_type=head['model_type'],
                    # BUG FIX: .get avoids the KeyError previously raised for
                    # non-'classifier' heads, which never had this key set.
                    output_non_linearity=head.get('output_non_linearity'),
                    dropout_p=dropout_p,
                    criterion=head['criterion'],
                    optimizer_name=None,
                    device=device)
            fcs.append(fc)
        if adaptive:
            trunk += meta['adaptive_head']
        else:
            trunk += meta['normal_head']
        model = nn.Sequential(*trunk)
        model.add_module('fc1', fcs[0])
        model.add_module('fc2', fcs[1])
        self.model = model
        self.head1 = head1
        self.head2 = head2

        print('Multi-head set up complete.')

    def train_(self, e, trainloader, optimizer, print_every):
        """Train one epoch.

        Args:
            e: ``(epoch_index, total_epochs)`` tuple.
            trainloader: iterable yielding ``(inputs, label1, label2)``.
            optimizer: optimizer stepping the model parameters.
            print_every: batches between progress printouts.

        Returns:
            float: average training loss over the epoch.
        """
        epoch, epochs = e
        self.train()
        t0 = time.time()  # reset every printout (per-chunk timing)
        t1 = time.time()  # fixed: total elapsed time for the epoch
        batches = 0
        running_loss = 0.
        for data_batch in trainloader:
            inputs, label1, label2 = data_batch[0], data_batch[1], data_batch[
                2]
            batches += 1
            inputs = inputs.to(self.device)
            label1 = label1.to(self.device)
            label2 = label2.to(self.device)
            labels = (label1, label2)
            optimizer.zero_grad()
            outputs = self.forward(inputs)
            loss = self.compute_loss(outputs, labels)[0]
            if self.parallel:
                # DataParallel returns per-GPU losses; reduce before backward.
                loss.sum().backward()
                loss = loss.sum()
            else:
                loss.backward()
                loss = loss.item()
            optimizer.step()
            running_loss += loss
            if batches % print_every == 0:
                elapsed = time.time() - t1
                if elapsed > 60:
                    elapsed /= 60.
                    measure = 'min'
                else:
                    measure = 'sec'
                batch_time = time.time() - t0
                if batch_time > 60:
                    batch_time /= 60.
                    measure2 = 'min'
                else:
                    measure2 = 'sec'
                print(
                    '+----------------------------------------------------------------------+\n'
                    f"{time.asctime().split()[-2]}\n"
                    f"Time elapsed: {elapsed:.3f} {measure}\n"
                    f"Epoch:{epoch+1}/{epochs}\n"
                    # BUG FIX: 'batches' is already 1-based here (incremented
                    # at loop start); the old 'batches+1' overshot by one.
                    f"Batch: {batches}/{len(trainloader)}\n"
                    f"Batch training time: {batch_time:.3f} {measure2}\n"
                    f"Batch training loss: {loss:.3f}\n"
                    f"Average training loss: {running_loss/(batches):.3f}\n"
                    '+----------------------------------------------------------------------+\n'
                )
                t0 = time.time()
        return running_loss / len(trainloader)

    def evaluate(self, dataloader, metric='accuracy'):
        """Evaluate combined loss (and optionally accuracy/rmse) on a
        two-label dataloader yielding ``(inputs, label1, label2)``.

        NOTE(review): the classifier branch calls ``labels.squeeze(0)`` and
        ``update_accuracies(outputs, labels)`` on the *(label1, label2)*
        tuple, which has no ``squeeze`` — this path looks broken for
        ``model_type == 'classifier'``. Left untouched pending confirmation
        of Classifier's expected inputs.
        """
        running_loss = 0.
        classifier = None

        if self.model_type == 'classifier':  # or self.num_classes is not None:
            classifier = Classifier(self.class_names)

        y_pred = []
        y_true = []

        self.eval()
        rmse_ = 0.
        with torch.no_grad():
            for data_batch in dataloader:
                inputs, label1, label2 = data_batch[0], data_batch[
                    1], data_batch[2]
                inputs = inputs.to(self.device)
                label1 = label1.to(self.device)
                label2 = label2.to(self.device)
                labels = (label1, label2)
                outputs = self.forward(inputs)
                loss = self.compute_loss(outputs, labels)[0]
                if self.parallel:
                    running_loss += loss.sum()
                    outputs = parallel.gather(outputs, self.device)
                else:
                    running_loss += loss.item()
                if classifier is not None and metric == 'accuracy':
                    classifier.update_accuracies(outputs, labels)
                    y_true.extend(list(labels.squeeze(0).cpu().numpy()))
                    _, preds = torch.max(torch.exp(outputs), 1)
                    y_pred.extend(list(preds.cpu().numpy()))
                elif metric == 'rmse':
                    rmse_ += rmse(outputs, labels).cpu().numpy()

        self.train()

        ret = {}
        if metric == 'rmse':
            print('Total rmse: {:.3f}'.format(rmse_))
            ret['final_rmse'] = rmse_ / len(dataloader)

        ret['final_loss'] = running_loss / len(dataloader)

        if classifier is not None:
            ret['accuracy'], ret[
                'class_accuracies'] = classifier.get_final_accuracies()
            ret['report'] = classification_report(
                y_true, y_pred, target_names=self.class_names)
            ret['confusion_matrix'] = confusion_matrix(y_true, y_pred)
            # roc_auc_score only works for binary/appropriately-shaped
            # targets; skip it silently otherwise.
            try:
                ret['roc_auc_score'] = roc_auc_score(y_true, y_pred)
            except Exception:
                pass
        return ret

    def evaluate_food(self, dataloader, metric='accuracy'):
        """Evaluate only the food head on a single-label dataloader
        yielding ``(inputs, labels)``.

        NOTE(review): no loss is computed here, so ``running_loss`` stays 0
        and ``ret['final_loss']`` is always 0.0 — kept for interface
        compatibility with :meth:`evaluate`.
        """
        running_loss = 0.
        classifier = Classifier(self.class_names)

        y_pred = []
        y_true = []

        self.eval()
        rmse_ = 0.
        with torch.no_grad():
            for data_batch in dataloader:
                inputs, labels = data_batch[0], data_batch[1]
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                # [0] selects the food head's logits.
                outputs = self.forward(inputs)[0]
                if metric == 'accuracy':
                    # Best-effort accuracy update; shape mismatches from
                    # odd batches are deliberately ignored.
                    try:
                        classifier.update_accuracies(outputs, labels)
                        y_true.extend(list(labels.squeeze(0).cpu().numpy()))
                        _, preds = torch.max(torch.exp(outputs), 1)
                        y_pred.extend(list(preds.cpu().numpy()))
                    except Exception:
                        pass
                elif metric == 'rmse':
                    rmse_ += rmse(outputs, labels).cpu().numpy()

        self.train()

        ret = {}
        if metric == 'rmse':
            print('Total rmse: {:.3f}'.format(rmse_))
            ret['final_rmse'] = rmse_ / len(dataloader)

        ret['final_loss'] = running_loss / len(dataloader)

        ret['accuracy'], ret[
            'class_accuracies'] = classifier.get_final_accuracies()
        ret['report'] = classification_report(
            y_true, y_pred, target_names=self.class_names)
        ret['confusion_matrix'] = confusion_matrix(y_true, y_pred)
        try:
            ret['roc_auc_score'] = roc_auc_score(y_true, y_pred)
        except Exception:
            pass
        return ret

    def find_lr(self,
                trn_loader,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                plot=False):
        """LR range test (Smith, 2015 / fast.ai style): sweep the learning
        rate geometrically from ``init_value`` to ``final_value``, track the
        smoothed loss, and pick an LR slightly before the loss minimum.
        Model/optimizer state is restored afterwards.

        Returns:
            float: the chosen learning rate (also stored in ``self.lr``).
        """
        print('\nFinding the ideal learning rate.')

        model_state = copy.deepcopy(self.model.state_dict())
        optim_state = copy.deepcopy(self.optimizer.state_dict())
        optimizer = self.optimizer
        num = len(trn_loader) - 1
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for data_batch in trn_loader:
            batch_num += 1
            inputs, label1, label2 = data_batch[0], data_batch[1], data_batch[
                2]
            inputs = inputs.to(self.device)
            label1 = label1.to(self.device)
            label2 = label2.to(self.device)
            labels = (label1, label2)
            optimizer.zero_grad()
            outputs = self.forward(inputs)
            loss = self.compute_loss(outputs, labels)[0]
            # Compute the exponentially smoothed loss (bias-corrected below).
            if self.parallel:
                avg_loss = beta * avg_loss + (1 - beta) * loss.sum()
            else:
                avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            # Stop if the loss is exploding.
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                self.log_lrs, self.find_lr_losses = log_lrs, losses
                self.model.load_state_dict(model_state)
                self.optimizer.load_state_dict(optim_state)
                if plot:
                    self.plot_find_lr()
                # Back off from the loss minimum by ~1/8 of the sweep.
                temp_lr = self.log_lrs[np.argmin(self.find_lr_losses) -
                                       (len(self.log_lrs) // 8)]
                self.lr = (10**temp_lr)
                print('Found it: {}\n'.format(self.lr))
                return self.lr
            # Record the best loss.
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values.
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            # Do the SGD step.
            if self.parallel:
                loss.sum().backward()
            else:
                loss.backward()
            optimizer.step()
            # Update the lr for the next step.
            lr *= mult
            optimizer.param_groups[0]['lr'] = lr

        self.log_lrs, self.find_lr_losses = log_lrs, losses
        self.model.load_state_dict(model_state)
        self.optimizer.load_state_dict(optim_state)
        if plot:
            self.plot_find_lr()
        temp_lr = self.log_lrs[np.argmin(self.find_lr_losses) -
                               (len(self.log_lrs) // 10)]
        self.lr = (10**temp_lr)
        print('Found it: {}\n'.format(self.lr))
        return self.lr

    def plot_find_lr(self):
        """Plot the find_lr loss curve against log10(learning rate)."""
        plt.ylabel("Loss")
        plt.xlabel("Learning Rate (log scale)")
        plt.plot(self.log_lrs, self.find_lr_losses)
        plt.show()

    def classify(self, inputs, thresh=0.4):
        """Predict food class names and ingredient names for ``inputs``.

        Args:
            inputs: batch of images accepted by :meth:`predict`.
            thresh: sigmoid threshold above which an ingredient counts
                as present.

        Returns:
            tuple: ``(class_preds, ing_preds)`` — list of food-name strings
            and list of per-sample ingredient-name arrays.
        """
        outputs = self.predict(inputs)
        food, ing = outputs
        # Handle both batched and single-sample food logits.
        try:
            _, preds = torch.max(torch.exp(food), 1)
        except Exception:
            _, preds = torch.max(torch.exp(food.unsqueeze(0)), 1)
        ing_outs = ing.sigmoid()
        ings = (ing_outs >= thresh)
        class_preds = [str(self.class_names[p]) for p in preds]
        # Uses the correctly spelled attribute set in set_model_params.
        ing_preds = [
            self.ingredient_names[p.nonzero().squeeze(1).cpu()] for p in ings
        ]
        return class_preds, ing_preds

    def _get_dropout(self):
        """Return the dropout probability recorded by set_model_params."""
        return self.dropout_p

    def get_model_params(self):
        """Return a dict of constructor-style parameters for checkpointing."""
        params = super(FoodIngredients, self).get_model_params()
        params['class_names'] = self.class_names
        params['num_classes'] = self.num_classes
        params['ingredient_names'] = self.ingredient_names
        params['num_ingredients'] = self.num_ingredients
        params['head1'] = self.head1
        params['head2'] = self.head2
        return params
Example #2
0
File: trainer.py  Project: krikit/jigsaw
class Trainer:
    """
    Trainer for a binary toxicity classifier: fine-tunes BERT on the train
    split, evaluates each epoch on the validation split, and early-stops on
    the F-score.
    """
    def __init__(self, cfg: Namespace, data: Dataset):
        """
        Args:
            cfg:  configuration (batch_size, bert_path, learning_rate,
                  epoch, patience, model_path)
            data:  train dataset (split 80/20 into train/validation here)
        """
        self.cfg = cfg
        self.train, self.valid = data.split(0.8)
        RATING_FIELD.build_vocab(self.train)

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')  # pylint: disable=no-member
        # Scale the batch size by GPU count so each device processes
        # cfg.batch_size examples under DataParallel.
        self.batch_size = cfg.batch_size
        if torch.cuda.is_available():
            self.batch_size *= torch.cuda.device_count()

        # Bucket by (negative) comment length so batches contain
        # similar-length sequences; descending order within each batch.
        self.trn_itr = BucketIterator(
            self.train,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            train=True,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        self.vld_itr = BucketIterator(
            self.valid,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            train=False,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        # Log roughly every 1% of the validation iterator (min step 10).
        self.log_step = 1000
        if len(self.vld_itr) < 100:
            self.log_step = 10
        elif len(self.vld_itr) < 1000:
            self.log_step = 100

        bert_path = cfg.bert_path if cfg.bert_path else 'bert-base-cased'
        self.model = BertForSequenceClassification.from_pretrained(
            bert_path, num_labels=2)
        # Up-weight the positive (toxic, target >= 0.5) class by the
        # negative:positive ratio to counter class imbalance.
        # NOTE(review): raises ZeroDivisionError if the train split has no
        # positive examples — confirm the data guarantees at least one.
        pos_weight = (
            len([exam for exam in self.train.examples if exam.target < 0.5]) /
            len([exam for exam in self.train.examples if exam.target >= 0.5]))
        pos_wgt_tensor = torch.tensor([1.0, pos_weight], device=self.device)  # pylint: disable=not-callable
        self.criterion = nn.CrossEntropyLoss(weight=pos_wgt_tensor)
        if torch.cuda.is_available():
            self.model = DataParallelModel(self.model.cuda())
            self.criterion = DataParallelCriterion(self.criterion)
        self.optimizer = optim.Adam(self.model.parameters(), cfg.learning_rate)

    def run(self):
        """
        Train for up to cfg.epoch epochs, checkpointing whenever the
        validation F-score improves and early-stopping after cfg.patience
        epochs without improvement.
        """
        max_f_score = -9e10
        max_epoch = -1
        for epoch in range(self.cfg.epoch):
            train_loss = self._train_epoch(epoch)
            metrics = self._evaluate(epoch)
            max_f_score_str = f' < {max_f_score:.2f}'
            if metrics['f_score'] > max_f_score:
                max_f_score_str = ' is max'
                max_f_score = metrics['f_score']
                max_epoch = epoch
                # NOTE(review): on GPU this saves the DataParallel-wrapped
                # state_dict (keys prefixed with 'module.') — confirm the
                # loading side expects that format.
                torch.save(self.model.state_dict(), self.cfg.model_path)
            logging.info('EPOCH[%d]: train loss: %.6f, valid loss: %.6f, acc: %.2f,' \
                         ' F: %.2f%s', epoch, train_loss, metrics['loss'],
                         metrics['accuracy'], metrics['f_score'], max_f_score_str)
            if (epoch - max_epoch) >= self.cfg.patience:
                logging.info('early stopping...')
                break
        logging.info('epoch: %d, f-score: %.2f', max_epoch, max_f_score)

    def _train_epoch(self, epoch: int) -> float:
        """
        train single epoch
        Args:
            epoch:  epoch number
        Returns:
            average loss
        """
        self.model.train()
        progress = tqdm(self.trn_itr,
                        f'EPOCH[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        for step, batch in enumerate(progress, start=1):
            outputs = self.model(batch.comment_text)
            # output of model wrapped with DataParallelModel is a list of outputs from each GPU
            # make input of DataParallelCriterion as a list of tuples
            if isinstance(self.model, DataParallelModel):
                loss = self.criterion([(output, ) for output in outputs],
                                      batch.target)
            else:
                loss = self.criterion(outputs, batch.target)
            # NOTE(review): .item() assumes the parallel criterion returns a
            # reduced scalar loss — confirm DataParallelCriterion's contract.
            losses.append(loss.item())
            if step % self.log_step == 0:
                avg_loss = sum(losses) / len(losses)
                progress.set_description(f'EPOCH[{epoch}] ({avg_loss:.6f})')
            loss.backward()
            self.optimizer.step()
            # Gradients are cleared after the step, which is equivalent to
            # clearing them before the next forward pass.
            self.optimizer.zero_grad()
        return sum(losses) / len(losses)

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """
        evaluate on validation data
        Args:
            epoch:  epoch number
        Returns:
            metrics: dict with 'accuracy', 'precision', 'recall',
                'f_score' (all in percent) plus average 'loss'
        """
        self.model.eval()
        progress = tqdm(self.vld_itr,
                        f' EVAL[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        preds = []
        golds = []
        for step, batch in enumerate(progress, start=1):
            with torch.no_grad():
                outputs = self.model(batch.comment_text)
                if isinstance(self.model, DataParallelModel):
                    loss = self.criterion([(output, ) for output in outputs],
                                          batch.target)
                    # Flatten per-GPU output lists; predict class 1 iff its
                    # logit exceeds the class-0 logit (argmax of two logits).
                    for output in outputs:
                        preds.extend([(0 if o[0] < o[1] else 1)
                                      for o in output])
                else:
                    loss = self.criterion(outputs, batch.target)
                    preds.extend([(0 if output[0] < output[1] else 1)
                                  for output in outputs])
                losses.append(loss.item())
                golds.extend([gold.item() for gold in batch.target])
                if step % self.log_step == 0:
                    avg_loss = sum(losses) / len(losses)
                    progress.set_description(
                        f' EVAL[{epoch}] ({avg_loss:.6f})')
        metrics = self._get_metrics(preds, golds)
        metrics['loss'] = sum(losses) / len(losses)
        return metrics

    @classmethod
    def _get_metrics(cls, preds: List[float],
                     golds: List[float]) -> Dict[str, float]:
        """
        get metric values (accuracy, precision, recall, F1), each scaled
        to percent; precision/recall/F default to 0.0 when undefined
        Args:
            preds:  predictions
            golds:  gold standards
        Returns:
            metric
        """
        assert len(preds) == len(golds)
        # Confusion-matrix counts; 0.5 is the positive-class threshold for
        # both predictions and gold labels.
        true_pos = 0
        false_pos = 0
        false_neg = 0
        true_neg = 0
        for pred, gold in zip(preds, golds):
            if pred >= 0.5:
                if gold >= 0.5:
                    true_pos += 1
                else:
                    false_pos += 1
            else:
                if gold >= 0.5:
                    false_neg += 1
                else:
                    true_neg += 1
        accuracy = (true_pos + true_neg) / (true_pos + false_pos + false_neg +
                                            true_neg)
        precision = 0.0
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        recall = 0.0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        f_score = 0.0
        if (precision + recall) > 0.0:
            f_score = 2.0 * precision * recall / (precision + recall)
        return {
            'accuracy': 100.0 * accuracy,
            'precision': 100.0 * precision,
            'recall': 100.0 * recall,
            'f_score': 100.0 * f_score,
        }
Example #3
0
def train(config):
    """Train a BERT masked-LM with a KL-divergence criterion.

    Builds train/valid datasets, a BertAdam optimizer with grouped weight
    decay, then runs the epoch/batch loop with periodic validation,
    plateau-based learning-rate decay, and checkpoint saving.

    Args:
        config: experiment configuration; fields read here include
            ``model``, ``dataParallel``, ``optionFrames``, ``batch_size``,
            ``lRate``, ``prob_src``, ``prob_tgt``, ``dataLoader_workers``,
            ``do_eval``, ``PAD`` and ``save_path``.
    """
    net = BertForMaskedLM.from_pretrained(config.model)
    lossFunc = KLDivLoss(config)

    # Move model and criterion to GPU when available; optionally wrap both
    # for multi-GPU training.
    if torch.cuda.is_available():
        net = net.cuda()
        lossFunc = lossFunc.cuda()

        if config.dataParallel:
            net = DataParallelModel(net)
            lossFunc = DataParallelCriterion(lossFunc)

    options = optionsLoader(LOG, config.optionFrames, disp=False)
    Tokenizer = BertTokenizer.from_pretrained(config.model)
    prepareFunc = prepare_data

    # Batches are keyed by total (source + target) token length.
    trainSet = Dataset('train', config.batch_size,
                       lambda x: len(x[0]) + len(x[1]), prepareFunc, Tokenizer,
                       options['dataset'], LOG, 'train')
    validSet = Dataset('valid', config.batch_size,
                       lambda x: len(x[0]) + len(x[1]), prepareFunc, Tokenizer,
                       options['dataset'], LOG, 'valid')

    print(trainSet.__len__())

    Q = []  # sliding window of recent batch losses (capped at 200 below)
    best_vloss = 1e99
    counter = 0  # steps since last validation improvement; drives LR decay
    lRate = config.lRate

    prob_src = config.prob_src
    prob_tgt = config.prob_tgt

    num_train_optimization_steps = trainSet.__len__(
    ) * options['training']['stopConditions']['max_epoch']
    param_optimizer = list(net.named_parameters())
    # Standard BERT practice: no weight decay on biases or LayerNorm params.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=lRate,
                         e=1e-9,
                         t_total=num_train_optimization_steps,
                         warmup=0.0)

    for epoch_idx in range(options['training']['stopConditions']['max_epoch']):
        # Per-epoch counters, reported at the end of the epoch.
        total_seen = 0
        total_similar = 0
        total_unseen = 0
        total_source = 0

        # batch_size=1 because the Dataset already yields pre-built batches;
        # the dummy batch dimension is stripped with `[0]` below.
        trainSet.setConfig(config, prob_src, prob_tgt)
        trainLoader = data.DataLoader(dataset=trainSet,
                                      batch_size=1,
                                      shuffle=True,
                                      num_workers=config.dataLoader_workers,
                                      pin_memory=True)

        # Validation uses prob_src = 0.0 (no source-side masking).
        validSet.setConfig(config, 0.0, prob_tgt)
        validLoader = data.DataLoader(dataset=validSet,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=config.dataLoader_workers,
                                      pin_memory=True)

        for batch_idx, batch_data in enumerate(trainLoader):
            # Periodic manual GC to keep long runs from accumulating garbage.
            if (batch_idx + 1) % 10000 == 0:
                gc.collect()
            start_time = time.time()

            net.train()

            inputs, positions, token_types, labels, masks, batch_seen, batch_similar, batch_unseen, batch_source = batch_data

            # `[0]` drops the DataLoader's batch dimension (batch_size=1).
            inputs = inputs[0].cuda()
            positions = positions[0].cuda()
            token_types = token_types[0].cuda()
            labels = labels[0].cuda()
            masks = masks[0].cuda()
            total_seen += batch_seen
            total_similar += batch_similar
            total_unseen += batch_unseen
            total_source += batch_source

            # NOTE(review): here a label of 0 is treated as padding, while the
            # validation loop below compares against config.PAD — confirm
            # config.PAD == 0, otherwise the two token counts disagree.
            n_token = int((labels.data != 0).data.sum())

            predicts = net(inputs, positions, token_types, masks)
            loss = lossFunc(predicts, labels, n_token).sum()

            # Maintain a running average over the last 200 batch losses.
            Q.append(float(loss))
            if len(Q) > 200:
                Q.pop(0)
            loss_avg = sum(Q) / len(Q)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            LOG.log(
                'Epoch %2d, Batch %6d, Loss %9.6f, Average Loss %9.6f, Time %9.6f'
                % (epoch_idx + 1, batch_idx + 1, loss, loss_avg,
                   time.time() - start_time))

            # Checkpoints
            # Global step index; checkpoint only after `checkMin` steps and
            # then every `checkFreq` steps.
            idx = epoch_idx * trainSet.__len__() + batch_idx + 1
            if (idx >= options['training']['checkingPoints']['checkMin']) and (
                    idx % options['training']['checkingPoints']['checkFreq']
                    == 0):
                if config.do_eval:
                    # Full pass over the validation set; loss is normalized by
                    # the total number of non-padding tokens.
                    vloss = 0
                    total_tokens = 0
                    for bid, batch_data in enumerate(validLoader):
                        inputs, positions, token_types, labels, masks, batch_seen, batch_similar, batch_unseen, batch_source = batch_data

                        inputs = inputs[0].cuda()
                        positions = positions[0].cuda()
                        token_types = token_types[0].cuda()
                        labels = labels[0].cuda()
                        masks = masks[0].cuda()

                        n_token = int((labels.data != config.PAD).data.sum())

                        with torch.no_grad():
                            net.eval()
                            predicts = net(inputs, positions, token_types,
                                           masks)
                            # NOTE(review): unlike the training call, n_token
                            # is not passed to lossFunc here — presumably it
                            # has a default; verify against KLDivLoss.
                            vloss += float(lossFunc(predicts, labels).sum())

                        total_tokens += n_token

                    # NOTE(review): ZeroDivisionError if validLoader is empty.
                    vloss /= total_tokens
                    is_best = vloss < best_vloss
                    best_vloss = min(vloss, best_vloss)
                    LOG.log(
                        'CheckPoint: Validation Loss %11.8f, Best Loss %11.8f'
                        % (vloss, best_vloss))

                    if is_best:
                        LOG.log('Best Model Updated')
                        save_check_point(
                            {
                                'epoch': epoch_idx + 1,
                                'batch': batch_idx + 1,
                                'options': options,
                                'config': config,
                                'state_dict': net.state_dict(),
                                'best_vloss': best_vloss
                            },
                            is_best,
                            path=config.save_path,
                            fileName='latest.pth.tar')
                        counter = 0
                    else:
                        # No improvement: accumulate patience; once the bound
                        # is reached, scale every param group's LR by 0.55.
                        counter += options['training']['checkingPoints'][
                            'checkFreq']
                        if counter >= options['training']['stopConditions'][
                                'rateReduce_bound']:
                            counter = 0
                            # NOTE(review): lr_/_lr leak out of this loop and
                            # the log below assumes at least one param group.
                            for param_group in optimizer.param_groups:
                                lr_ = param_group['lr']
                                param_group['lr'] *= 0.55
                                _lr = param_group['lr']
                            LOG.log(
                                'Reduce Learning Rate from %11.8f to %11.8f' %
                                (lr_, _lr))
                        LOG.log('Current Counter = %d' % (counter))

                else:
                    # No evaluation requested: just snapshot the model state.
                    save_check_point(
                        {
                            'epoch': epoch_idx + 1,
                            'batch': batch_idx + 1,
                            'options': options,
                            'config': config,
                            'state_dict': net.state_dict(),
                            'best_vloss': 1e99
                        },
                        False,
                        path=config.save_path,
                        fileName='checkpoint_Epoch' + str(epoch_idx + 1) +
                        '_Batch' + str(batch_idx + 1) + '.pth.tar')
                    LOG.log('CheckPoint Saved!')

        if options['training']['checkingPoints']['everyEpoch']:
            # Optional end-of-epoch snapshot, independent of validation.
            save_check_point(
                {
                    'epoch': epoch_idx + 1,
                    'batch': batch_idx + 1,
                    'options': options,
                    'config': config,
                    'state_dict': net.state_dict(),
                    'best_vloss': 1e99
                },
                False,
                path=config.save_path,
                fileName='checkpoint_Epoch' + str(epoch_idx + 1) + '.pth.tar')

        LOG.log('Epoch Finished.')
        LOG.log(
            'Total Seen: %d, Total Unseen: %d, Total Similar: %d, Total Source: %d.'
            % (total_seen, total_unseen, total_similar, total_source))
        gc.collect()
# ===== Example 4 =====
def main_tr(args, crossVal):
    """Train MiniSeg for one cross-validation fold.

    Loads the fold's data, builds the model/criterion/optimizer, trains at
    several image scales per epoch, validates, and checkpoints after every
    epoch.

    Args:
        args: parsed command-line arguments (data paths, training
            hyper-parameters, GPU flags, ...).
        crossVal: zero-based index of the cross-validation fold.

    Returns:
        Tuple ``(maxEpoch, maxmIOU)``: the 1-based epoch that achieved the
        best validation mIOU, and that mIOU value.
    """
    dataLoad = ld.LoadData(args.data_dir, args.classes)
    data = dataLoad.processData(crossVal, args.data_name)

    # load the model
    model = net.MiniSeg(args.classes, aux=True)
    # Build <savedir>_mod<max_epochs>/<data_name>/<model_name> step by step.
    # NOTE(review): os.mkdir (not makedirs) — fails if parents are missing.
    if not osp.isdir(osp.join(args.savedir + '_mod' + str(args.max_epochs))):
        os.mkdir(args.savedir + '_mod' + str(args.max_epochs))
    if not osp.isdir(
            osp.join(args.savedir + '_mod' + str(args.max_epochs),
                     args.data_name)):
        os.mkdir(
            osp.join(args.savedir + '_mod' + str(args.max_epochs),
                     args.data_name))
    saveDir = args.savedir + '_mod' + str(
        args.max_epochs) + '/' + args.data_name + '/' + args.model_name
    # create the directory if not exist
    if not osp.exists(saveDir):
        os.mkdir(saveDir)

    if args.gpu and torch.cuda.device_count() > 1:
        #model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)
    if args.gpu:
        model = model.cuda()

    total_paramters = sum([np.prod(p.size()) for p in model.parameters()])
    print('Total network parameters: ' + str(total_paramters))

    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    if args.gpu:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight, args.ignore_label)  #weight
    if args.gpu and torch.cuda.device_count() > 1:
        criteria = DataParallelCriterion(criteria)
    if args.gpu:
        criteria = criteria.cuda()

    # compose the data with transforms
    # Main scale: normalize, resize to target size, small random crop, flip.
    trainDataset_main = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.RandomCropResize(int(32. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    # Extra scales (1.5x, 1.25x, 0.75x) for multi-scale augmentation.
    trainDataset_scale1 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 1.5), int(args.height * 1.5)),
        myTransforms.RandomCropResize(int(100. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])

    trainDataset_scale2 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 1.25), int(args.height * 1.25)),
        myTransforms.RandomCropResize(int(100. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    trainDataset_scale3 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 0.75), int(args.height * 0.75)),
        myTransforms.RandomCropResize(int(32. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])

    # Validation: deterministic — normalize and resize only.
    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.ToTensor()
    ])

    # since we training from scratch, we create data loaders at different scales
    # so that we can generate more augmented data and prevent the network from overfitting
    trainLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['trainIm'], data['trainAnnot'], transform=trainDataset_main),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    trainLoader_scale1 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale1),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    trainLoader_scale2 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale2),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    trainLoader_scale3 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale3),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    valLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)
    # Total batches per epoch across all four scales (used for progress/LR).
    max_batches = len(trainLoader) + len(trainLoader_scale1) + len(
        trainLoader_scale2) + len(trainLoader_scale3)

    if args.gpu:
        cudnn.benchmark = True

    start_epoch = 0

    # Warm-start from a pretrained checkpoint, skipping prediction heads
    # ('pred' keys) so class counts may differ; strict=False tolerates gaps.
    if args.pretrained is not None:
        state_dict = torch.load(args.pretrained)
        new_keys = []
        new_values = []
        for idx, key in enumerate(state_dict.keys()):
            if 'pred' not in key:
                new_keys.append(key)
                new_values.append(list(state_dict.values())[idx])
        new_dict = OrderedDict(list(zip(new_keys, new_values)))
        model.load_state_dict(new_dict, strict=False)
        print('pretrained model loaded')

    # Resume a previous run: restores epoch, LR and full model weights.
    if args.resume is not None:
        if osp.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Append to an existing log; write the column header only on a new file.
    log_file = osp.join(saveDir, 'trainValLog_' + args.model_name + '.txt')
    if osp.isfile(log_file):
        logger = open(log_file, 'a')
    else:
        logger = open(log_file, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t%s\t\t%s\t%s\t%s\t%s\tlr" %
                     ('CrossVal', 'Epoch', 'Loss(Tr)', 'Loss(val)',
                      'mIOU (tr)', 'mIOU (val)'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-4)
    maxmIOU = 0
    maxEpoch = 0
    print(args.model_name + '-CrossVal: ' + str(crossVal + 1))
    for epoch in range(start_epoch, args.max_epochs):
        # train for one epoch
        cur_iter = 0

        # Train on the three auxiliary scales first, then the main scale.
        # NOTE(review): only the final (main-scale) call's metrics are kept
        # and logged for the epoch.
        train(args, trainLoader_scale1, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale1)
        train(args, trainLoader_scale2, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale2)
        train(args, trainLoader_scale3, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale3)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr, lr = \
                train(args, trainLoader, model, criteria, optimizer, epoch, max_batches, cur_iter)

        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = \
                val(args, valLoader, model, criteria)

        # Full training-state checkpoint (overwritten every epoch).
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'iouTr': mIOU_tr,
                'iouVal': mIOU_val,
                'lr': lr
            },
            osp.join(
                saveDir, 'checkpoint_' + args.model_name + '_crossVal' +
                str(crossVal + 1) + '.pth.tar'))

        # save the model also
        model_file_name = osp.join(
            saveDir, 'model_' + args.model_name + '_crossVal' +
            str(crossVal + 1) + '_' + str(epoch + 1) + '.pth')
        torch.save(model.state_dict(), model_file_name)

        logger.write(
            "\n%d\t\t%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
            (crossVal + 1, epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\n" \
                % (epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val))

        # Track the best validation mIOU and the epoch that produced it.
        if mIOU_val >= maxmIOU:
            maxmIOU = mIOU_val
            maxEpoch = epoch + 1
        torch.cuda.empty_cache()
    logger.flush()
    logger.close()
    return maxEpoch, maxmIOU