Example 1
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: VariationalEncoder,
                 decoder: VariationalDecoder,
                 generator: Model,
                 discriminator: Model,
                 mse_weight: float = 2.0,
                 train_temperature: float = 1.0,
                 inference_temperature: float = 1e-5,
                 num_responses: int = 10) -> None:
        super().__init__(vocab)
        self._encoder = encoder
        self._decoder = decoder
        self._mse_weight = mse_weight
        self.train_temperature = train_temperature
        self.inference_temperature = inference_temperature
        self._num_responses = num_responses
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
        self.s_bleu4 = NLTKSentenceBLEU(
            n_hyps=self._num_responses,
            smoothing_function=SmoothingFunction().method7,
            exclude_indices={
                self._pad_index, self._end_index, self._start_index
            },
            prefix='_S_BLEU4')
        self.n_bleu2 = NLTKSentenceBLEU(ngram_weights=(1 / 2, 1 / 2),
                                        n_hyps=self._num_responses,
                                        exclude_indices={
                                            self._pad_index, self._end_index,
                                            self._start_index
                                        },
                                        prefix='_BLEU2')

        # We need our optimizer to know which parameters came from
        # which model, so we cheat by adding tags to them.
        for param in generator.parameters():
            setattr(param, '_generator', True)
        for param in discriminator.parameters():
            setattr(param, '_discriminator', True)

        self.generator = generator
        self.discriminator = discriminator
        # Running averages of discriminator training metrics.
        self._disc_metrics = {
            "dfl": Average(),
            "dfacc": Average(),
            "drl": Average(),
            "dracc": Average(),
        }

        # Running averages of generator training metrics.
        self._gen_metrics = {
            "_gl": Average(),
            "gce": Average(),
            "_gmse": Average(),
            "_mean": Average(),
            "_stdev": Average()
        }
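A minimal sketch of how the _generator / _discriminator tags set above could later be consumed to build separate optimizer parameter groups; the toy nn.Linear modules and Adam optimizers are assumptions for illustration, not part of the example itself:

import torch
import torch.nn as nn

generator = nn.Linear(8, 8)        # stand-in for the generator Model
discriminator = nn.Linear(8, 1)    # stand-in for the discriminator Model
for param in generator.parameters():
    setattr(param, '_generator', True)
for param in discriminator.parameters():
    setattr(param, '_discriminator', True)

all_params = list(generator.parameters()) + list(discriminator.parameters())
gen_optimizer = torch.optim.Adam(
    [p for p in all_params if hasattr(p, '_generator')], lr=1e-3)
disc_optimizer = torch.optim.Adam(
    [p for p in all_params if hasattr(p, '_discriminator')], lr=1e-3)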
Example 2
def enable_gradient_clipping(model: Model, grad_clipping: Optional[float]) -> None:
    """
    Registers a backward hook on every trainable parameter that clamps its
    gradient element-wise to [-grad_clipping, grad_clipping]. Does nothing
    when grad_clipping is None.
    """
    if grad_clipping is not None:
        for parameter in model.parameters():
            if parameter.requires_grad:
                parameter.register_hook(lambda grad: nn_util.clamp_tensor(grad,
                                                                          minimum=-grad_clipping,
                                                                          maximum=grad_clipping))
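A minimal, self-contained sketch of the same hook-based clipping pattern, using plain torch.Tensor.clamp in place of nn_util.clamp_tensor; the toy linear model is an assumption for illustration:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
grad_clipping = 1.0

for parameter in model.parameters():
    if parameter.requires_grad:
        # The hook runs during backward() and clamps each gradient element-wise.
        parameter.register_hook(
            lambda grad: grad.clamp(min=-grad_clipping, max=grad_clipping))

loss = model(torch.randn(8, 4)).sum()
loss.backward()
print(all(float(p.grad.abs().max()) <= grad_clipping for p in model.parameters()))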
Example 3
def rescale_gradients(model: Model, grad_norm: Optional[float] = None) -> Optional[float]:
    """
    Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
    """
    if grad_norm:
        parameters_to_clip = [p for p in model.parameters() if p.grad is not None]
        return sparse_clip_norm(parameters_to_clip, grad_norm)
    return None
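A minimal sketch of the same norm-based rescaling, with torch.nn.utils.clip_grad_norm_ standing in for sparse_clip_norm (AllenNLP's variant behaves similarly but also supports sparse gradients):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

grad_norm = 5.0
parameters_to_clip = [p for p in model.parameters() if p.grad is not None]
# Returns the total gradient norm before rescaling, mirroring the function above.
total_norm = torch.nn.utils.clip_grad_norm_(parameters_to_clip, grad_norm)
print(float(total_norm))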
Example 4
    def from_params(cls, params: Params,
                    model: Model) -> 'UpdateMovingAverage':  # type: ignore
        # pylint: disable=arguments-differ
        moving_average_params = params.pop("moving_average")
        moving_average = MovingAverage.from_params(
            params=moving_average_params, parameters=model.parameters())

        return UpdateMovingAverage(moving_average)
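A minimal sketch of the exponential moving average such a callback maintains over model parameters; the decay value and the toy model are assumptions for illustration only:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
decay = 0.999
shadow = {name: p.detach().clone() for name, p in model.named_parameters()}

def apply_moving_average() -> None:
    # shadow <- decay * shadow + (1 - decay) * parameter, called after each update.
    with torch.no_grad():
        for name, p in model.named_parameters():
            shadow[name].mul_(decay).add_(p, alpha=1 - decay)

apply_moving_average()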
Example 5
    def from_params(cls, model: Model, serialization_dir: str,
                    iterator: DataIterator, iterator_aux: DataIterator,
                    train_dataset: Dataset, train_dataset_aux: Dataset,
                    mixing_ratio: float, cutoff_epoch: int,
                    validation_dataset: Optional[Dataset],
                    validation_dataset_aux: Optional[Dataset], params: Params,
                    files_to_archive: Dict[str, str]) -> 'MultiTaskTrainer':

        patience = params.pop("patience", 2)
        validation_metric = params.pop("validation_metric", "-loss")
        num_epochs = params.pop("num_epochs", 20)
        cuda_device = params.pop("cuda_device", -1)
        grad_norm = params.pop("grad_norm", None)
        grad_clipping = params.pop("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)

        if cuda_device >= 0:
            model = model.cuda(cuda_device)
        parameters = [p for p in model.parameters() if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))

        if lr_scheduler_params:
            scheduler = LearningRateScheduler.from_params(
                optimizer, lr_scheduler_params)
        else:
            scheduler = None
        no_tqdm = params.pop("no_tqdm", False)

        params.assert_empty(cls.__name__)
        return MultiTaskTrainer(model=model,
                                optimizer=optimizer,
                                iterator=iterator,
                                iterator_aux=iterator_aux,
                                train_dataset=train_dataset,
                                train_dataset_aux=train_dataset_aux,
                                mixing_ratio=mixing_ratio,
                                cutoff_epoch=cutoff_epoch,
                                validation_dataset=validation_dataset,
                                validation_dataset_aux=validation_dataset_aux,
                                patience=patience,
                                validation_metric=validation_metric,
                                num_epochs=num_epochs,
                                serialization_dir=serialization_dir,
                                files_to_archive=files_to_archive,
                                cuda_device=cuda_device,
                                grad_norm=grad_norm,
                                grad_clipping=grad_clipping,
                                learning_rate_scheduler=scheduler,
                                no_tqdm=no_tqdm)
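A minimal sketch of the pop-with-default / assert_empty pattern used above, with a small stand-in class instead of allennlp.common.Params:

class SimpleParams:
    def __init__(self, params: dict) -> None:
        self._params = dict(params)

    def pop(self, key, default=None):
        return self._params.pop(key, default)

    def assert_empty(self, class_name: str) -> None:
        # Fail loudly on unread configuration keys, as Params.assert_empty does.
        if self._params:
            raise ValueError(f"Extra parameters passed to {class_name}: {self._params}")

params = SimpleParams({"patience": 5, "num_epochs": 10})
patience = params.pop("patience", 2)
num_epochs = params.pop("num_epochs", 20)
params.assert_empty("MultiTaskTrainer")  # would raise if any key were left unread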
Example 6
    def __init__(self,
                 rpn: Model,
                 train_rpn: bool = False,
                 pooler_sampling_ratio: int = 2,
                 decoder_thresh: float = 0.1,
                 decoder_nms_thresh: float = 0.5,
                 decoder_detections_per_image: int = 100,
                 matcher_high_thresh: float = 0.5,
                 matcher_low_thresh: float = 0.5,
                 allow_low_quality_matches: bool = True,
                 batch_size_per_image: int = 64,
                 balance_sampling_fraction: float = 0.25):
        feedforward = FeedForward(7 * 7 * 256, 2, [1024, 1024], nn.ReLU())
        encoder = FlattenEncoder(256, 7, 7, feedforward)
        vocab = Vocabulary({
            'labels': {k: 1
                       for k in PretrainedDetectronFasterRCNN.CATEGORIES}
        })
        box_roi_head = FasterRCNNROIHead(
            encoder,
            pooler_resolution=7,
            pooler_sampling_ratio=pooler_sampling_ratio,
            matcher_low_thresh=matcher_low_thresh,
            matcher_high_thresh=matcher_high_thresh,
            decoder_thresh=decoder_thresh,
            decoder_nms_thresh=decoder_nms_thresh,
            decoder_detections_per_image=decoder_detections_per_image,
            allow_low_quality_matches=allow_low_quality_matches,
            batch_size_per_image=batch_size_per_image,
            balance_sampling_fraction=balance_sampling_fraction)
        super(PretrainedDetectronFasterRCNN,
              self).__init__(vocab, rpn, box_roi_head, train_rpn=train_rpn)
        # Initialise the box classification and regression heads from
        # torchvision's pretrained Faster R-CNN.
        frcnn = fasterrcnn_resnet50_fpn(pretrained=True,
                                        pretrained_backbone=True)
        self._box_classifier.load_state_dict(
            frcnn.roi_heads.box_predictor.cls_score.state_dict())
        self._bbox_pred.load_state_dict(
            frcnn.roi_heads.box_predictor.bbox_pred.state_dict())

        # pylint: disable = protected-access
        feedforward._linear_layers[0].load_state_dict(
            frcnn.roi_heads.box_head.fc6.state_dict())
        feedforward._linear_layers[1].load_state_dict(
            frcnn.roi_heads.box_head.fc7.state_dict())

        for p in rpn.parameters():
            p.requires_grad = train_rpn
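A minimal sketch of the two patterns used above, copying a state_dict between matching layers and freezing a sub-module via requires_grad, using toy layers instead of the Faster R-CNN heads:

import torch.nn as nn

pretrained_fc = nn.Linear(1024, 1024)   # stand-in for frcnn.roi_heads.box_head.fc6
target_fc = nn.Linear(1024, 1024)       # stand-in for feedforward._linear_layers[0]

# Copy weights layer by layer; the shapes must match exactly.
target_fc.load_state_dict(pretrained_fc.state_dict())

# Freeze (or unfreeze) an entire sub-module by toggling requires_grad.
train_rpn = False
rpn = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1), nn.ReLU())
for p in rpn.parameters():
    p.requires_grad = train_rpn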
Example 7
    def from_params(cls, model: Model, serialization_dir: str,
                    iterator: DataIterator, train_dataset: Dataset,
                    validation_dataset: Optional[Dataset],
                    params: Params) -> 'Trainer':

        patience = params.pop("patience", 2)
        validation_metric = params.pop("validation_metric", "-loss")
        num_epochs = params.pop("num_epochs", 20)
        cuda_device = params.pop("cuda_device", -1)
        grad_norm = params.pop("grad_norm", None)
        grad_clipping = params.pop("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)

        if cuda_device >= 0:
            model = model.cuda(cuda_device)
        parameters = [p for p in model.parameters() if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))

        if lr_scheduler_params:
            scheduler = LearningRateScheduler.from_params(
                optimizer, lr_scheduler_params)
        else:
            scheduler = None
        no_tqdm = params.pop("no_tqdm", False)

        params.assert_empty(cls.__name__)
        return Trainer(model,
                       optimizer,
                       iterator,
                       train_dataset,
                       validation_dataset,
                       patience=patience,
                       validation_metric=validation_metric,
                       num_epochs=num_epochs,
                       serialization_dir=serialization_dir,
                       cuda_device=cuda_device,
                       grad_norm=grad_norm,
                       grad_clipping=grad_clipping,
                       learning_rate_scheduler=scheduler,
                       no_tqdm=no_tqdm)
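A minimal sketch of the optional learning-rate-scheduler branch, using a plain torch.optim scheduler in place of AllenNLP's LearningRateScheduler; the StepLR settings are assumptions for illustration:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1)

lr_scheduler_params = {"step_size": 10, "gamma": 0.5}  # e.g. read from the config
if lr_scheduler_params:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **lr_scheduler_params)
else:
    scheduler = None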