Example #1
def test(dataloader, model, model_path=None):
    if model_path:
        torch_load(model_path, model)
    model.eval()
    stats = collections.defaultdict(list)
    for batch_idx, data in enumerate(dataloader):
        logging.warning(f"Testing batch: {batch_idx+1}/{len(dataloader)}")
        fbank, seq_lens, tokens = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
        with torch.no_grad():
            loss = model(fbank, seq_lens, tokens)
        stats["loss_lst"].append(loss.item())
        # The forward pass stores accuracy on the model; with a DataParallel
        # wrapper, the underlying model's attributes live under .module.
        if not hasattr(model, "module"):
            if model.acc is not None:
                stats["acc_lst"].append(model.acc)
                model.acc = None
        else:
            if model.module.acc is not None:
                stats["acc_lst"].append(model.module.acc)
                model.module.acc = None
    return dict_average(stats)
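
Both this example and Example #3 return dict_average(stats), a helper that is not shown. A minimal sketch of what it presumably does, averaging each list of per-batch values into a scalar (an assumption, not the project's actual implementation):

def dict_average(stats):
    # Average each list of per-batch values into a single number.
    return {key: sum(values) / len(values) for key, values in stats.items()}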
Example #2
def train_approximator(**kwargs):
    parsing.set_log_level(kwargs['log_level'])

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                None,
                                False,
                                kwargs['masked_relu'],
                                True,
                                load_weights=False,
                                as_detector=True)
    model.train()

    train_dataset = parsing.parse_dataset(kwargs['domain'],
                                          kwargs['dataset'],
                                          allow_standard=False)
    val_dataset = None

    if kwargs['from_adversarial_dataset']:
        # TODO: Check
        train_dataset = train_dataset.to_distance_dataset()
    elif isinstance(train_dataset, ad.AdversarialDataset):
        raise click.BadArgumentUsage(
            'Expected a distance dataset as training dataset, got an adversarial dataset. '
            'If this is intentional, use --from-adversarial-dataset.')

    val_dataloader = None
    if kwargs['validation_split'] != 0:
        train_dataset, val_dataset = training.split_dataset(
            train_dataset, kwargs['validation_split'], shuffle=True)
    elif kwargs['validation_dataset'] is not None:
        val_dataset = parsing.parse_dataset(kwargs['domain'],
                                            kwargs['validation_dataset'],
                                            allow_standard=False)

        if kwargs['val_from_adversarial_dataset']:
            # TODO: Check
            val_dataset = val_dataset.to_distance_dataset()
        elif isinstance(val_dataset, ad.AdversarialDataset):
            raise click.BadArgumentUsage(
                'Expected a distance dataset as validation dataset, got an adversarial dataset. '
                'If this is intentional, use --val-from-adversarial-dataset.')

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   kwargs['batch_size'],
                                                   shuffle=kwargs['shuffle'])

    if val_dataset is None:
        val_dataloader = None
    else:
        # There is no point in shuffling the validation dataset
        val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                     kwargs['batch_size'],
                                                     shuffle=False)

    early_stopping = None
    if kwargs['early_stopping'] > 0:
        early_stopping = training.EarlyStopping(
            kwargs['early_stopping'], delta=kwargs['early_stopping_delta'])

    # TODO: Mean or Sum?
    loss = torch.nn.MSELoss()
    optimiser = parsing.parse_optimiser(kwargs['optimiser'],
                                        model.parameters(), kwargs)

    if kwargs['checkpoint_every'] is None:
        checkpoint_path = None
    else:
        checkpoint_path = kwargs['save_to'] + '-checkpoint'

    if kwargs['load_checkpoint'] is None:
        loaded_checkpoint = None
    else:
        loaded_checkpoint = utils.torch_load(kwargs['load_checkpoint'])

    training.train(model,
                   train_dataloader,
                   optimiser,
                   loss,
                   kwargs['epochs'],
                   kwargs['device'],
                   val_loader=val_dataloader,
                   l1_regularization=kwargs['l1_regularization'],
                   early_stopping=early_stopping,
                   checkpoint_every=kwargs['checkpoint_every'],
                   checkpoint_path=checkpoint_path,
                   loaded_checkpoint=loaded_checkpoint,
                   choose_best=kwargs['choose_best'])

    save_to = kwargs['save_to']
    pathlib.Path(save_to).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_to)
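
training.split_dataset, used here and in Example #5, is not shown. A minimal sketch of the presumed behaviour, built on torch.utils.data.random_split (an assumption; the project's own helper may differ):

import torch

def split_dataset(dataset, split, shuffle=True):
    # Hold out a `split` fraction of the dataset for validation.
    val_size = int(len(dataset) * split)
    train_size = len(dataset) - val_size
    if shuffle:
        return torch.utils.data.random_split(dataset, [train_size, val_size])
    # Without shuffling, take the validation examples from the tail.
    train_set = torch.utils.data.Subset(dataset, range(train_size))
    val_set = torch.utils.data.Subset(dataset, range(train_size, len(dataset)))
    return train_set, val_set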
Example #3
def test(epoch,
         dataloader,
         model,
         model_path=None,
         language=None,
         visualize_sim_adapter=False):
    if model_path:
        torch_load(model_path, model)
    orig_model = None
    if hasattr(model, "module"):
        orig_model = model
        model = model.module
    model.eval()
    stats = collections.defaultdict(list)
    # Accumulators for the sim-adapter attention statistics; these must
    # persist across batches rather than being reset every iteration.
    init_mat = lambda: np.zeros((len(model.fusion_languages), ))
    avg_atts = collections.defaultdict(init_mat)
    count = collections.defaultdict(int)
    for batch_idx, data in enumerate(dataloader):
        logging.warning(f"Testing batch: {batch_idx+1}/{len(dataloader)}")
        if len(data) == 4:
            fbank, seq_lens, tokens, language = data
        else:
            assert language is not None
            fbank, seq_lens, tokens = data
        fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
        with torch.no_grad():
            loss = model(fbank, seq_lens, tokens, language)

        if visualize_sim_adapter:
            atts = model.calculate_sim_adapter_attentions(
                fbank, seq_lens, tokens, language)
            for key in atts.keys():
                avg_atts[key] += atts[key].sum(axis=(0, 1))
                count[key] += atts[key].shape[0] * atts[key].shape[1]
        stats["loss_lst"].append(loss.item())
        # model was unwrapped from any DataParallel wrapper above, so the
        # accuracy is read directly from it.
        if model.acc is not None:
            stats["acc_lst"].append(model.acc)
            model.acc = None
    if visualize_sim_adapter:
        for key in avg_atts.keys():
            avg_atts[key] = avg_atts[key] / count[key]
            logging.warning(f"Attention scores of {key}: {avg_atts[key]}")
        fig = plt.figure(figsize=(16, 8))
        ax = fig.subplots()
        atts, labels = [], []
        for key in avg_atts.keys():
            atts.append(avg_atts[key])
            labels.append(key)
        atts = np.stack(atts)
        tick_marks = np.arange(len(labels))
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(labels)
        x_labels = list(sorted(model.fusion_languages))
        ax.set_xticks(np.arange(len(x_labels)))
        ax.set_xticklabels(x_labels)
        ax.imshow(atts)
        import itertools
        for i, j in itertools.product(range(atts.shape[0]),
                                      range(atts.shape[1])):
            plt.text(j,
                     i,
                     "{:0.2f}".format(atts[i, j]),
                     horizontalalignment="center",
                     color="white")
        fig.tight_layout()
        fig.savefig(f"{args.outdir}/att_{epoch}.png")
        plt.close()
    if orig_model is not None:
        model = orig_model
    return dict_average(stats)
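
The torch_load helper called at the top of Examples #1, #3 and #4 takes a path plus the objects to restore. A hypothetical sketch, assuming a snapshot dict with 'model' and 'optimizer' keys (the key names are guesses, not confirmed by the source):

import torch

def torch_load(path, model, optimizer=None):
    # Restore model (and optionally optimizer) state from a snapshot file.
    snapshot = torch.load(path, map_location='cpu')
    model.load_state_dict(snapshot['model'])
    if optimizer is not None:
        optimizer.load_state_dict(snapshot['optimizer'])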
Example #4
        # Collect, per language, the trainable parameters whose dotted name
        # contains the language tag, then build one optimiser per language.
        model_params = collections.defaultdict(list)
        optimizer = {}
        for lang in args.adapter_train_languages:
            for name, parameter in model.named_parameters():
                if parameter.requires_grad and lang in name.split("."):
                    model_params[lang].append(parameter)
            logging.warning(
                f"Number of trainable parameters for language {lang} " +
                str(sum(p.numel() for p in model_params[lang])))
            optimizer[lang] = torch.optim.Adam(model_params[lang],
                                               lr=args.adam_lr,
                                               weight_decay=args.weight_decay)

    # Resume from a snapshot
    if args.resume:
        logging.warning("resumed from %s" % args.resume)
        torch_load(args.resume, model, optimizer)
        setattr(args, "start_epoch", int(args.resume.split('.')[-1]) + 1)
    else:
        setattr(args, "start_epoch", 1)

    if args.load_pretrained_model:
        model_path, modules_to_load, exclude_modules = args.load_pretrained_model.split(
            ":")
        logging.warning("load pretrained model from %s" %
                        args.load_pretrained_model)
        load_pretrained_model(model=model,
                              model_path=model_path,
                              modules_to_load=modules_to_load,
                              exclude_modules=exclude_modules)
    if args.load_head_from_pretrained_model:
        logging.warning("load pretrained model head from %s" %
Example #5
def train_classifier(**kwargs):
    parsing.set_log_level(kwargs['log_level'])
    logger.debug('Running train-classifier command with kwargs %s', kwargs)

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                None,
                                False,
                                kwargs['masked_relu'],
                                True,
                                load_weights=False)
    model.train()

    extra_transforms = []

    if kwargs['flip']:
        extra_transforms.append(torchvision.transforms.RandomHorizontalFlip())

    if kwargs['rotation'] != 0 or kwargs['translation'] != 0:
        if kwargs['translation'] < 0 or kwargs['translation'] > 1:
            logger.warning('The suggested range for --translation is [0, 1].')

        if kwargs['rotation'] < 0 or kwargs['rotation'] > 180:
            logger.warning('The suggested range for --rotation is [0, 180].')

        translation = (
            kwargs['translation'],
            kwargs['translation']) if kwargs['translation'] != 0 else None
        extra_transforms.append(
            torchvision.transforms.RandomAffine(kwargs['rotation'],
                                                translation))

    train_dataset = parsing.parse_dataset(kwargs['domain'],
                                          kwargs['dataset'],
                                          extra_transforms=extra_transforms)

    # Validation
    val_dataset = None

    if kwargs['validation_dataset'] is not None and kwargs[
            'validation_split'] != 0:
        raise click.BadOptionUsage(
            '--validation_split',
            '--validation_dataset and --validation_split are mutually exclusive.'
        )

    if kwargs['validation_split'] != 0:
        logger.debug('Performing a validation split.')
        train_dataset, val_dataset = training.split_dataset(
            train_dataset, kwargs['validation_split'], shuffle=True)
    elif kwargs['validation_dataset'] is not None:
        logger.debug('Loading an existing validation dataset.')
        val_dataset = parsing.parse_dataset(kwargs['domain'],
                                            kwargs['validation_dataset'],
                                            allow_standard=True)

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   kwargs['batch_size'],
                                                   shuffle=kwargs['shuffle'])
    if val_dataset is None:
        val_dataloader = None
    else:
        val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                     kwargs['batch_size'],
                                                     shuffle=False)

    # Early stopping
    early_stopping = None
    if kwargs['early_stopping'] > 0:
        if kwargs['choose_best'] and kwargs['early_stopping_delta'] != 0:
            logger.warning(
                'Received --choose-best and --early-stopping with delta != 0. '
                'Remember that with delta != 0, --choose-best and --early-stopping '
                'track the best loss and state_dict differently.')
        logger.debug('Adding early stopping.')
        early_stopping = training.EarlyStopping(
            kwargs['early_stopping'], delta=kwargs['early_stopping_delta'])

    # Adversarial training
    if kwargs['adversarial_training'] == []:
        adversarial_attack = None

        if kwargs['adversarial_ratio'] is not None:
            logger.warning(
                'Received --adversarial-ratio without --adversarial-training.')
        if kwargs['adversarial_p'] is not None:
            logger.warning(
                'Received --adversarial-p without --adversarial-training.')
        if kwargs['adversarial_eps'] is not None:
            logger.warning(
                'Received --adversarial-eps without --adversarial-training.')
        if kwargs['adversarial_eps_growth_epoch'] != 0:
            logger.warning(
                'Received --adversarial-eps-growth-epoch without --adversarial-training.'
            )
        if kwargs['adversarial_eps_growth_start'] is not None:
            logger.warning(
                'Received --adversarial-eps-growth-start without --adversarial-training.'
            )
    else:
        logger.debug('Enabling adversarial training.')

        if kwargs['adversarial_ratio'] is None:
            raise click.BadOptionUsage(
                '--adversarial-ratio',
                'Please specify the ratio for adversarial training with --adversarial-ratio.'
            )

        if kwargs['adversarial_ratio'] <= 0 or kwargs['adversarial_ratio'] > 1:
            raise click.BadOptionUsage(
                '--adversarial-ratio',
                '--adversarial-ratio must be between 0 (exclusive) and 1 (inclusive).'
            )

        if kwargs['adversarial_p'] is None:
            raise click.BadOptionUsage(
                '--adversarial-p',
                'Please specify the Lp norm for adversarial training with --adversarial-p.'
            )

        if kwargs['adversarial_eps'] is None:
            raise click.BadOptionUsage(
                '--adversarial-eps',
                'Please specify the maximum perturbation norm for adversarial training with --adversarial-eps (inf is also allowed).'
            )

        if kwargs['adversarial_eps_growth_epoch'] > 0:
            if kwargs['adversarial_eps_growth_start'] is None:
                raise click.BadOptionUsage(
                    '--adversarial-eps-growth-start',
                    'Please specify the initial value for adversarial epsilon growth with --adversarial-eps-growth-start '
                    '(0 is also allowed).')

            if kwargs['early_stopping'] > 0:
                logger.warning(
                    'Received --adversarial-eps-growth-epoch and --early-stopping together.'
                )
        elif kwargs['adversarial_eps_growth_start'] is not None:
            logger.warning(
                'Received --adversarial-eps-growth-start without --adversarial-eps-growth-epoch.'
            )

        attack_config = utils.read_attack_config_file(
            kwargs['attack_config_file'])

        adversarial_attack = parsing.parse_attack_pool(
            kwargs['adversarial_training'],
            kwargs['domain'],
            kwargs['adversarial_p'],
            'training',
            model,
            attack_config,
            kwargs['device'],
            seed=kwargs['seed'])

    # RS loss
    if kwargs['rs_regularization'] == 0:
        if kwargs['rs_eps'] is not None:
            logger.warning('Received --rs-eps without --rs-regularization.')
        if kwargs['rs_start_epoch'] != 1:
            logger.warning(
                'Received --rs-start_epoch without --rs-regularization.')
    else:
        if kwargs['rs_eps'] is None:
            raise click.BadOptionUsage(
                '--rs-eps',
                'Please specify the maximum perturbation for RS loss with --rs-eps.'
            )

        if kwargs['rs_start_epoch'] > kwargs['epochs']:
            logger.warning(
                '--rs-start-epoch is higher than the number of epochs. This means that RS loss will never be activated.'
            )

        if kwargs['rs_start_epoch'] > 1 and kwargs['early_stopping'] > 0:
            logger.warning(
                'Received --rs-start-epoch and --early-stopping together.')

    # Use Mean Cross Entropy, consistent with Xiao and Madry's ReLU training technique
    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    optimiser = parsing.parse_optimiser(kwargs['optimiser'],
                                        model.parameters(), kwargs)

    if kwargs['checkpoint_every'] is None:
        checkpoint_path = None
    else:
        checkpoint_path = kwargs['save_to'] + '-checkpoint'

    if kwargs['load_checkpoint'] is None:
        loaded_checkpoint = None
    else:
        loaded_checkpoint = utils.torch_load(kwargs['load_checkpoint'])

    training.train(
        model,
        train_dataloader,
        optimiser,
        loss,
        kwargs['epochs'],
        kwargs['device'],
        val_loader=val_dataloader,
        l1_regularization=kwargs['l1_regularization'],
        rs_regularization=kwargs['rs_regularization'],
        rs_eps=kwargs['rs_eps'],
        rs_minibatch_size=kwargs['rs_minibatch'],
        rs_start_epoch=kwargs['rs_start_epoch'],
        early_stopping=early_stopping,
        attack=adversarial_attack,
        attack_ratio=kwargs['adversarial_ratio'],
        attack_p=kwargs['adversarial_p'],
        attack_eps=kwargs['adversarial_eps'],
        attack_eps_growth_epoch=kwargs['adversarial_eps_growth_epoch'],
        attack_eps_growth_start=kwargs['adversarial_eps_growth_start'],
        checkpoint_every=kwargs['checkpoint_every'],
        checkpoint_path=checkpoint_path,
        loaded_checkpoint=loaded_checkpoint,
        choose_best=kwargs['choose_best'])

    save_to = kwargs['save_to']
    pathlib.Path(save_to).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_to)
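
training.EarlyStopping is constructed here and in Example #2 from a patience value and a delta. A minimal sketch of that interface (hypothetical; as the warning above suggests, the project's class may also track the best loss and state_dict):

class EarlyStopping:
    def __init__(self, patience, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        # An improvement larger than delta resets the patience counter;
        # anything else consumes one unit of patience.
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience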