Example #1
    ComputeMelSpectrogramFromMagSpectrogram(num_features=num_features,
                                            normalize=args.normalize,
                                            eps=eps)
])

if args.dataset == 'librispeech':
    from datasets.libri_speech import LibriSpeech as SpeechDataset, vocab

    max_duration = 16.7
    train_dataset = ConcatDataset([
        SpeechDataset(name='train-clean-100',
                      max_duration=max_duration,
                      transform=train_transform),
        SpeechDataset(name='train-clean-360',
                      max_duration=max_duration,
                      transform=train_transform),
        SpeechDataset(name='train-other-500',
                      max_duration=max_duration,
                      transform=train_transform),
        ColoredNoiseDataset(size=5000, transform=train_transform),
        BackgroundSounds(size=1000, transform=train_transform)
    ])
    valid_dataset = SpeechDataset(name='dev-clean', transform=valid_transform)
elif args.dataset == 'bolorspeech':
    from datasets.bolor_speech import BolorSpeech as SpeechDataset, vocab

    max_duration = 16.7
    train_dataset = ConcatDataset([
        SpeechDataset(name='train',
                      max_duration=max_duration,
                      transform=train_transform),
Example #2
    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])
Example #3
    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])
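
The two tests above pin down what ConcatDataset guarantees: its length is the sum of the sub-dataset lengths, and a global index is routed to the correct sub-dataset even when an empty one sits in between. A minimal sketch of that lookup, written here with plain Python lists and a hypothetical SimpleConcat class rather than the torch implementation itself, could look like this:

import bisect
from itertools import accumulate


class SimpleConcat:
    """Illustrative only: route a global index to (sub-dataset, local index)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Running totals of sub-dataset lengths, e.g. [5, 5, 10] for the test above.
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx):
        # Binary search finds the first sub-dataset whose running total exceeds idx.
        ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        local_idx = idx if ds_idx == 0 else idx - self.cumulative_sizes[ds_idx - 1]
        return self.datasets[ds_idx][local_idx]


assert len(SimpleConcat([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])) == 10
assert SimpleConcat([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])[5] == 5

With the empty middle list the running totals are [5, 5, 10], so index 5 skips past the first two entries and resolves to local index 0 of the third list, which is exactly what the second test asserts.
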
Example #4
        for name, param in model.named_parameters():
            missing_keys.add(name)
        for key, data in state_dict.items():
            if key in missing_keys:
                matched_state_dict[key] = data
                missing_keys.remove(key)
            else:
                unexpected_keys.add(key)
        print("Unexpected_keys:", list(unexpected_keys))
        print("Missing_keys:", list(missing_keys))
        model.load_state_dict(matched_state_dict, strict=False)

    # use train & val splits to optimize, only available for vqa, not vqa_cp
    if args.use_both and args.dataset == "vqa":
        length = len(val_dset)
        trainval_concat_dset = ConcatDataset([train_dset, val_dset])
        if args.use_vg or args.use_visdial:
            trainval_concat_dsets_split = random_split(trainval_concat_dset, [
                int(0.2 * length),
                len(trainval_concat_dset) - int(0.2 * length)
            ])
        else:
            trainval_concat_dsets_split = random_split(trainval_concat_dset, [
                int(0.1 * length),
                len(trainval_concat_dset) - int(0.1 * length)
            ])
        concat_list = [trainval_concat_dsets_split[1]]

        # use a portion of Visual Genome dataset
        if args.use_vg:
            vg_train_dset = VisualGenomeFeatureDataset(
Example #5
def generate_random_dataset(path,
                            nb_row_valid,
                            nb_rows_test,
                            nb_rows,
                            dict_nb_lignes,
                            size_image=224,
                            encoding_dict=None,
                            filenames=None,
                            use_acc_proportionate_sampling=False):
    '''

    For each class in filenames, take nb_rows random rows from its file

    :param path:
    :param nb_row_valid:
    :param nb_rows_test:
    :param nb_rows:
    :param size_image:
    :param encoding_dict:
    :param filenames:
    :return:
    '''

    if filenames is None:
        filenames = os.listdir(path)

    if use_acc_proportionate_sampling:
        if os.path.isfile("saves_obj/dict_acc_per_class_valid.pk"):
            dict_acc_class = load_object(
                "saves_obj/dict_acc_per_class_valid.pk")
        else:
            print(
                "No per-class accuracy dictionary found; using uniform sampling"
            )
            # The accuracy dictionary is needed below, so fall back to uniform
            # sampling when it is missing.
            use_acc_proportionate_sampling = False

    nb_lignes_skip = nb_row_valid + nb_rows_test
    list_dataset = []

    dict_nb_row_used_per_class = {}

    for fn in filenames:
        n = dict_nb_lignes[fn]
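        # skiprows keeps row 0 (typically the CSV header), always skips the rows
        # reserved for the validation/test sets (1..nb_lignes_skip-1), and then
        # randomly skips all but nb_rows of the remaining rows, so only a random
        # sample of each class file is ever read.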
        skip = list(range(1, nb_lignes_skip)) + sorted(
            random.sample(range(nb_lignes_skip, n),
                          n - nb_rows - nb_lignes_skip))

        if use_acc_proportionate_sampling:
            acc = dict_acc_class[fn]
            new_rows = round((1.1 - acc) * nb_rows)

        else:
            new_rows = nb_rows
        dict_nb_row_used_per_class[fn] = new_rows

        data_set = DoodlesDataset(fn,
                                  path,
                                  nrows=new_rows,
                                  size=size_image,
                                  skiprows=skip,
                                  encoding_dict=encoding_dict,
                                  mode="train")
        list_dataset.append(data_set)

    doodles = ConcatDataset(list_dataset)

    print("Nombre de données d'entraînement:", dict_nb_row_used_per_class)

    return doodles
Example #6
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        required=True,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--dataset",
        default="squad",
        type=str,
        help="Name of dataset; choices between SQuAD, NewsQA and HotPotQA")
    parser.add_argument("--features_file",
                        type=str,
                        help="Load features of original train set")

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument(
        '--version_2_with_negative',
        action='store_true',
        help=
        'If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument(
        '--null_score_diff_threshold',
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null."
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument(
        "--early_stopping",
        default=-1,
        type=int,
        help=
        "If > 0: use that many evaluations on dev set to stop training if dev F1 score does not improve."
    )
    parser.add_argument(
        "--warmup_ratio",
        default=0,
        type=float,
        help="Linear warmup over [warmup_ratio*total_steps] steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json output file."
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")

    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    #if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
    #    raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    HP_TRANSFORMS = [
        'AddSentDiverse', 'AddKSentDiverse', 'AddAnswerPosition',
        'InvalidateAnswer', 'PerturbAnswer', 'AddSentDiverse-PerturbAnswer',
        'AddKSentDiverse-PerturbAnswer', 'AddAnswerPosition-PerturbAnswer'
    ]

    # Create output directory if needed
    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir='/ssd-playpen/home/adyasha/cache/')
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir='/ssd-playpen/home/adyasha/cache/')
    train_dataset = load_and_cache_examples(args,
                                            tokenizer,
                                            evaluate=False,
                                            output_examples=False)

    for i, p in enumerate([0, 0.2, 0.4, 0.6, 0.8, 1.0]):

        if i == 0:
            continue
        policy = []
        for t in HP_TRANSFORMS:
            policy += [t, p]
        assert len(policy) % 2 == 0

        if i != 0:
            augmented_data_path = augment_with_adversaries(policy,
                                                           args.train_file,
                                                           prefix='Rand%s' % i,
                                                           keep_original=False)

        model = model_class.from_pretrained(
            args.model_name_or_path,
            from_tf=bool('.ckpt' in args.model_name_or_path),
            config=config,
            cache_dir='/ssd-playpen/home/adyasha/cache/')

        if args.local_rank == 0:
            torch.distributed.barrier(
            )  # Make sure only the first process in distributed training will download model & vocab

        model.to(args.device)

        logger.info("Training/evaluation parameters %s", args)

        # Training
        if args.do_train:

            if i != 0:
                augment_dataset = load_and_cache_examples(
                    args,
                    tokenizer,
                    evaluate=False,
                    output_examples=False,
                    input_data_file=augmented_data_path)
                # if args.features_file is not None:
                #     original_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False,
                #                                                from_saved=args.features_file)
                train_dataset = ConcatDataset([train_dataset, augment_dataset])
            global_step, tr_loss = train(args, train_dataset, model, tokenizer)
            logger.info(" global_step = %s, average loss = %s", global_step,
                        tr_loss)

        # Save the trained model and the tokenizer
        if args.local_rank == -1 or torch.distributed.get_rank() == 0:

            logger.info("Saving model checkpoint to %s", args.output_dir)
            # Save a trained model, configuration and tokenizer using `save_pretrained()`.
            # They can then be reloaded using `from_pretrained()`
            model_to_save = model.module if hasattr(
                model, 'module'
            ) else model  # Take care of distributed/parallel training
            model_to_save.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

            # Good practice: save your training arguments together with the trained model
            torch.save(args, os.path.join(args.output_dir,
                                          'training_args.bin'))

            # Load a trained model and vocabulary that you have fine-tuned
            model = model_class.from_pretrained(args.output_dir)
            tokenizer = tokenizer_class.from_pretrained(
                args.output_dir, cache_dir='/ssd-playpen/home/adyasha/cache')
            model.to(args.device)

        # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
        results = {}
        if args.do_eval and args.local_rank in [-1, 0]:
            checkpoints = [args.output_dir]
            checkpoint_dir = os.path.dirname(args.output_dir)
            print(checkpoint_dir)

            if args.eval_all_checkpoints:
                checkpoints = list(
                    os.path.dirname(c) for c in sorted(
                        glob.glob(checkpoint_dir + '/**/' + WEIGHTS_NAME,
                                  recursive=True)))
                logging.getLogger(
                    "pytorch_transformers.modeling_utils").setLevel(
                        logging.WARN)  # Reduce model loading logs

            logger.info("Evaluate the following checkpoints: %s", checkpoints)

            f_results = open(os.path.join('results-' + str(i) + '.json'), 'a+')

            dev_dataset, dev_examples, dev_features = load_and_cache_examples(
                args, tokenizer, evaluate=True, output_examples=True)

            for checkpoint in checkpoints:
                # Reload the model
                global_step = checkpoint.split(
                    '-')[-1] if len(checkpoints) > 1 else ""
                model = model_class.from_pretrained(checkpoint)
                model.to(args.device)

                # Evaluate
                result = evaluate(args,
                                  model,
                                  tokenizer,
                                  prefix=global_step,
                                  dataset=dev_dataset,
                                  examples=dev_examples,
                                  features=dev_features)

                result = dict(
                    (k + ('_{}'.format(global_step) if global_step else ''), v)
                    for k, v in result.items())
                results.update(result)

                f_results.write(
                    json.dumps({checkpoint + '_' + args.predict_file: result},
                               indent=2))

            f_results.close()

        logger.info("Results: {}".format(results))

    return results
Example #7
def mtl_train(args, config, train_set, dev_set, label_map, bert_model,
              clf_head):
    save_dir = "./models/{}".format(utils.get_savedir_name())
    tb_writer = SummaryWriter(os.path.join(save_dir, "logs"))

    train_set = ConcatDataset(train_set)
    train_loader = DataLoader(
        dataset=train_set,
        sampler=utils.BalancedTaskSampler(dataset=train_set,
                                          batch_size=config.batch_size),
        batch_size=config.batch_size,
        collate_fn=utils.collate_fn,
        shuffle=False,
        num_workers=0,
    )
    dev_set = ConcatDataset(dev_set)
    dev_loader = DataLoader(
        dataset=dev_set,
        batch_size=config.batch_size,
        collate_fn=utils.collate_fn,
        shuffle=False,
        num_workers=0,
    )
    num_epochs = config.num_epochs

    if not config.finetune_enc:
        for param in bert_model.parameters():
            param.requires_grad = False
        extra = []
    else:
        extra = list(bert_model.named_parameters())

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in list(clf_head.named_parameters()) + extra
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            config.weight_decay,
        },
        {
            "params": [
                p for n, p in list(clf_head.named_parameters()) + extra
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    opt = AdamW(optimizer_grouped_parameters, eps=1e-8, lr=config.outer_lr)

    best_dev_error = np.inf
    if args.load_from:
        state_obj = torch.load(os.path.join(args.load_from, "optim.th"))
        opt.load_state_dict(state_obj["optimizer"])
        num_epochs = num_epochs - state_obj["last_epoch"]
        bert_model = bert_model.eval()
        clf_head = clf_head.eval()
        dev_loss, dev_metrics = utils.compute_loss_metrics(
            dev_loader,
            bert_model,
            clf_head,
            label_map,
            grad_required=False,
            return_metrics=False,
        )
        best_dev_error = dev_loss.mean()

    patience_ctr = 0
    for epoch in range(num_epochs):
        running_loss = 0.0
        epoch_iterator = tqdm(train_loader, desc="Training")
        for (
                train_step,
            (input_ids, attention_mask, token_type_ids, labels, _, _),
        ) in enumerate(epoch_iterator):
            # train
            bert_model.train()
            clf_head.train()
            opt.zero_grad()
            bert_output = bert_model(input_ids, attention_mask, token_type_ids)
            output = clf_head(bert_output,
                              labels=labels,
                              attention_mask=attention_mask)
            loss = output.loss.mean()
            loss.backward()
            if config.finetune_enc:
                torch.nn.utils.clip_grad_norm_(bert_model.parameters(),
                                               config.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(clf_head.parameters(),
                                           config.max_grad_norm)
            opt.step()
            running_loss += loss.item()
            # eval at the beginning of every epoch and after every `config.eval_freq` steps
            if train_step % config.eval_freq == 0:
                bert_model.eval()
                clf_head.eval()
                dev_loss, dev_metrics = utils.compute_loss_metrics(
                    dev_loader,
                    bert_model,
                    clf_head,
                    label_map,
                    grad_required=False,
                    return_metrics=False,
                )
                dev_loss = dev_loss.mean()

                tb_writer.add_scalar("metrics/loss", dev_loss, epoch)
                if dev_metrics is not None:
                    tb_writer.add_scalar("metrics/precision",
                                         dev_metrics["precision"], epoch)
                    tb_writer.add_scalar("metrics/recall",
                                         dev_metrics["recall"], epoch)
                    tb_writer.add_scalar("metrics/f1", dev_metrics["f1"],
                                         epoch)
                    logger.info(
                        "Dev. metrics (p/r/f): {:.3f} {:.3f} {:.3f}".format(
                            dev_metrics["precision"],
                            dev_metrics["recall"],
                            dev_metrics["f1"],
                        ))

                if dev_loss < best_dev_error:
                    logger.info("Found new best model!")
                    best_dev_error = dev_loss
                    save(clf_head, opt, args.config_path, epoch, bert_model)
                    patience_ctr = 0
                else:
                    patience_ctr += 1
                    if patience_ctr == config.patience:
                        logger.info(
                            "Ran out of patience. Stopping training early...")
                        return

        logger.info(
            f"Finished epoch {epoch+1} with avg. training loss: {running_loss/(train_step + 1)}"
        )

    logger.info(f"Best validation loss = {best_dev_error}")
    logger.info("Best model saved at: {}".format(utils.get_savedir_name()))
Example #8
def _get_datasets(dataset, dataroot, load_train:bool, load_test:bool,
        transform_train, transform_test, train_max_size:int, test_max_size:int)\
            ->Tuple[DatasetLike, DatasetLike]:
    logger = get_logger()
    trainset, testset = None, None

    if dataset == 'cifar10':
        if load_train:
            # NOTE: train transforms will also be applied to validation set
            trainset = torchvision.datasets.CIFAR10(root=dataroot,
                                                    train=True,
                                                    download=True,
                                                    transform=transform_train)
        if load_test:
            testset = torchvision.datasets.CIFAR10(root=dataroot,
                                                   train=False,
                                                   download=True,
                                                   transform=transform_test)
    elif dataset == 'mnist':
        if load_train:
            trainset = torchvision.datasets.MNIST(root=dataroot,
                                                  train=True,
                                                  download=True,
                                                  transform=transform_train)
        if load_test:
            testset = torchvision.datasets.MNIST(root=dataroot,
                                                 train=False,
                                                 download=True,
                                                 transform=transform_test)
    elif dataset == 'fashionmnist':
        if load_train:
            trainset = torchvision.datasets.FashionMNIST(
                root=dataroot,
                train=True,
                download=True,
                transform=transform_train)
        if load_test:
            testset = torchvision.datasets.FashionMNIST(
                root=dataroot,
                train=False,
                download=True,
                transform=transform_test)
    elif dataset == 'reduced_cifar10':
        if load_train:
            trainset = torchvision.datasets.CIFAR10(root=dataroot,
                                                    train=True,
                                                    download=True,
                                                    transform=transform_train)
            sss = StratifiedShuffleSplit(n_splits=1, test_size=46000)  # 4000
            sss = sss.split(list(range(len(trainset))), trainset.targets)
            train_idx, valid_idx = next(sss)
            targets = [trainset.targets[idx] for idx in train_idx]
            trainset = Subset(trainset, train_idx)
            trainset.targets = targets
        if load_test:
            testset = torchvision.datasets.CIFAR10(root=dataroot,
                                                   train=False,
                                                   download=True,
                                                   transform=transform_test)
    elif dataset == 'cifar100':
        if load_train:
            trainset = torchvision.datasets.CIFAR100(root=dataroot,
                                                     train=True,
                                                     download=True,
                                                     transform=transform_train)
        if load_test:
            testset = torchvision.datasets.CIFAR100(root=dataroot,
                                                    train=False,
                                                    download=True,
                                                    transform=transform_test)
    elif dataset == 'svhn':
        if load_train:
            trainset = torchvision.datasets.SVHN(root=dataroot,
                                                 split='train',
                                                 download=True,
                                                 transform=transform_train)
            extraset = torchvision.datasets.SVHN(root=dataroot,
                                                 split='extra',
                                                 download=True,
                                                 transform=transform_train)
            trainset = ConcatDataset([trainset, extraset])
        if load_test:
            testset = torchvision.datasets.SVHN(root=dataroot,
                                                split='test',
                                                download=True,
                                                transform=transform_test)
    elif dataset == 'reduced_svhn':
        if load_train:
            trainset = torchvision.datasets.SVHN(root=dataroot,
                                                 split='train',
                                                 download=True,
                                                 transform=transform_train)
            sss = StratifiedShuffleSplit(n_splits=1,
                                         test_size=73257 - 1000)  #1000
            sss = sss.split(list(range(len(trainset))), trainset.targets)
            train_idx, valid_idx = next(sss)
            targets = [trainset.targets[idx] for idx in train_idx]
            trainset = Subset(trainset, train_idx)
            trainset.targets = targets
        if load_test:
            testset = torchvision.datasets.SVHN(root=dataroot,
                                                split='test',
                                                download=True,
                                                transform=transform_test)
    elif dataset == 'imagenet':
        if load_train:
            trainset = ImageNet(root=os.path.join(dataroot,
                                                  'imagenet-pytorch'),
                                transform=transform_train)
            # compatibility
            trainset.targets = [lb for _, lb in trainset.samples]
        if load_test:
            testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                               split='val',
                               transform=transform_test)
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        idx120 = [
            904, 385, 759, 884, 784, 844, 132, 214, 990, 786, 979, 582, 104,
            288, 697, 480, 66, 943, 308, 282, 118, 926, 882, 478, 133, 884,
            570, 964, 825, 656, 661, 289, 385, 448, 705, 609, 955, 5, 703, 713,
            695, 811, 958, 147, 6, 3, 59, 354, 315, 514, 741, 525, 685, 673,
            657, 267, 575, 501, 30, 455, 905, 860, 355, 911, 24, 708, 346, 195,
            660, 528, 330, 511, 439, 150, 988, 940, 236, 803, 741, 295, 111,
            520, 856, 248, 203, 147, 625, 589, 708, 201, 712, 630, 630, 367,
            273, 931, 960, 274, 112, 239, 463, 355, 955, 525, 404, 59, 981,
            725, 90, 782, 604, 323, 418, 35, 95, 97, 193, 690, 869, 172
        ]
        if load_train:
            trainset = ImageNet(root=os.path.join(dataroot,
                                                  'imagenet-pytorch'),
                                transform=transform_train)
            # compatibility
            trainset.targets = [lb for _, lb in trainset.samples]

            sss = StratifiedShuffleSplit(n_splits=1,
                                         test_size=len(trainset) - 500000,
                                         random_state=0)  # 4000
            sss = sss.split(list(range(len(trainset))), trainset.targets)
            train_idx, valid_idx = next(sss)

            # filter out
            train_idx = list(
                filter(lambda x: trainset.labels[x] in idx120, train_idx))
            valid_idx = list(
                filter(lambda x: trainset.labels[x] in idx120, valid_idx))

            targets = [
                idx120.index(trainset.targets[idx]) for idx in train_idx
            ]
            for idx in range(len(trainset.samples)):
                if trainset.samples[idx][1] not in idx120:
                    continue
                trainset.samples[idx] = (trainset.samples[idx][0],
                                         idx120.index(
                                             trainset.samples[idx][1]))
            trainset = Subset(trainset, train_idx)
            trainset.targets = targets
        if load_test:
            testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                               split='val',
                               transform=transform_test)
            test_idx = list(filter(lambda x: testset.samples[x][1] in \
                idx120, range(len(testset))))
            for idx in range(len(testset.samples)):
                if testset.samples[idx][1] not in idx120:
                    continue
                testset.samples[idx] = (testset.samples[idx][0],
                                        idx120.index(testset.samples[idx][1]))
            testset = Subset(testset, test_idx)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if train_max_size > 0:
        logger.warn({'train_max_batches': train_max_size})
        trainset = LimitDataset(trainset, train_max_size)
    if test_max_size > 0:
        logger.warn({'test_max_batches': test_max_size})
        testset = LimitDataset(testset, test_max_size)

    return trainset, testset
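
As a quick usage sketch (the paths, transforms and batch size here are hypothetical, and the torchvision downloads are assumed to succeed), the 'svhn' branch above returns a ConcatDataset of the 'train' and 'extra' splits that plugs straight into a standard DataLoader:

import torchvision.transforms as T
from torch.utils.data import DataLoader

# Build the concatenated SVHN train+extra set; skip the test split and the size caps.
trainset, _ = _get_datasets('svhn', './data',
                            load_train=True, load_test=False,
                            transform_train=T.ToTensor(), transform_test=T.ToTensor(),
                            train_max_size=0, test_max_size=0)
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
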
Example #9
    def load_dataset(self, split, combine=False):
        """Load a dataset split."""

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            token_path = os.path.join(self.args.data, split_k)

            if IndexedInMemoryDataset.exists(token_path):
                token_ds = IndexedInMemoryDataset(token_path,
                                                  fix_lua_indexing=True)
                tokens = token_ds.buffer

                sizes = token_ds.sizes

                in_tsv_file_path = os.path.join(self.args.data,
                                                f'gap-{split}.bert.tsv')
                gap_reader = GAP_Reader(in_tsv_file_path, is_gold=True)
                gap_data = gap_reader.read()

                in_bert_file_path = os.path.join(self.args.data,
                                                 f'gap-{split}.bert.jsonl')

                gap_bert_reader = Bert_Reader(in_bert_file_path)
                gap_bert_data = gap_bert_reader.read()
                gap_bert_weights = [
                    bert_weights for _, bert_weights in gap_bert_data
                ]

                gap_texts = [d.text.split() for d in gap_data]
                assert np.array_equal(sizes, [len(t) + 1 for t in gap_texts])
                assert np.array_equal(
                    sizes,
                    [len(bert_tokens) + 1 for bert_tokens, _ in gap_bert_data])
                assert np.array_equal(
                    [d.text.split(" ") for d in gap_data],
                    [bert_tokens for bert_tokens, _ in gap_bert_data])

                gap_corefs = self.generate_gap_coref_supervision(
                    gap_data, sizes)
                assert len(gap_data) == len(gap_corefs)

            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            loaded_datasets.append(
                TokenBlockGapBertDataset(
                    tokens,
                    sizes,
                    self.args.tokens_per_sample,
                    gap_data,
                    gap_corefs,
                    gap_bert_weights,
                    break_mode=self.args.sample_break_mode,
                    include_targets=True))

            if split == "train":
                gap_dataset = TokenBlockGapBertDataset(
                    tokens,
                    sizes,
                    self.args.tokens_per_sample,
                    gap_data,
                    gap_corefs,
                    gap_bert_weights,
                    self.args.sample_break_mode,
                    include_targets=True)
                self.datasets["train_gap_only"] = MonolingualGapBertDataset(
                    gap_dataset,
                    gap_dataset.sizes,
                    self.token_dictionary,
                    shuffle=False)

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MonolingualGapBertDataset(dataset,
                                                         sizes,
                                                         self.token_dictionary,
                                                         shuffle=False)
Example #10
    def _make_question_answering(cls, datasilo, sets=["train", "dev", "test"], n_splits=5, shuffle=True,
                                 random_state=None, n_neg_answers_per_question=1):
        """
        Create number of folds data-silo-like objects which can be used for training from the
        original data silo passed on. This function takes into account the characteristics of the
        data for question-answering.

        :param datasilo: the data silo that contains the original data
        :type datasilo: DataSilo
        :param sets: which sets to use to create the xval folds (strings)
        :type sets: list
        :param n_splits: number of folds to create
        :type n_splits: int
        :param shuffle: shuffle each class' samples before splitting
        :type shuffle: bool
        :param random_state: random state for shuffling
        :type random_state: int
        :param n_neg_answers_per_question: number of negative answers per question to include for training
        :type n_neg_answers_per_question: int
        """
        assert "id" in datasilo.tensor_names, f"Expected tensor 'id' in tensor names, found {datasilo.tensor_names}"
        assert "labels" in datasilo.tensor_names, f"Expected tensor 'labels' in tensor names, found {datasilo.tensor_names}"

        id_index = datasilo.tensor_names.index("id")
        label_index = datasilo.tensor_names.index("labels")

        sets_to_concat = []
        for setname in sets:
            if datasilo.data[setname]:
                sets_to_concat.extend(datasilo.data[setname])
        all_data = ConcatDataset(sets_to_concat)

        documents = []
        keyfunc = lambda x: x[id_index][0]
        all_data = sorted(all_data.datasets, key=keyfunc)
        for key, document in groupby(all_data, key=keyfunc):
            documents.append(list(document))

        xval_split = cls._split_for_qa(documents = documents,
                                       id_index=id_index,
                                       n_splits=n_splits,
                                       shuffle=shuffle,
                                       random_state=random_state,
                                       )
        silos = []

        for train_set, test_set in xval_split:
            # Each training set is further divided into actual train and dev set
            if datasilo.processor.dev_split > 0:
                dev_split = datasilo.processor.dev_split
                n_dev = int(np.ceil(dev_split * len(train_set)))
                assert n_dev > 0, f"dev split of {dev_split} is not large enough to split away a development set"
                n_actual_train = len(train_set) - n_dev
                actual_train_set = train_set[:n_actual_train]
                dev_set = train_set[n_actual_train:]
                ds_dev = [sample for document in dev_set for sample in document]
            else:
                ds_dev = None
                actual_train_set = train_set

            train_samples = []
            for doc in actual_train_set:
                keyfunc = lambda x: x[id_index][1]
                doc = sorted(doc, key=keyfunc)
                for key, question in groupby(doc, key=keyfunc):
                    # add all available answers to the train set
                    sample_list = list(question)
                    neg_answer_idx = []
                    for index, sample in enumerate(sample_list):
                        if sample[label_index][0][0] or sample[label_index][0][1]:
                            train_samples.append(sample)
                        else:
                            neg_answer_idx.append(index)
                    # add random n_neg_answers_per_question samples to train set
                    if len(neg_answer_idx) <= n_neg_answers_per_question:
                        train_samples.extend([sample_list[idx] for idx in neg_answer_idx])
                    else:
                        neg_answer_idx = random.sample(neg_answer_idx, n_neg_answers_per_question)
                        train_samples.extend([sample_list[idx] for idx in neg_answer_idx])

            ds_train = train_samples
            ds_test = [sample for document in test_set for sample in document]
            silos.append(DataSiloForCrossVal(datasilo, ds_train, ds_dev, ds_test))
        return silos
Example #11
        test_ma_f1 = f1_score(y_true, y_pred, average='macro')
        test_accuracy = correct/total
        writer.add_scalar('accuracy_test/accuracy_test', test_accuracy, global_batch_counter_test)
        writer.add_scalar('accuracy_test/micro_f1_test', test_mi_f1, global_batch_counter_test)
        writer.add_scalar('accuracy_test/macro_f1_test', test_ma_f1, global_batch_counter_test)
        print('[%d] epoch, [%.3f] training loss, [%.3f] testing loss, [%.3f] testing accuracy'
              %(epoch, loss_train/len(train_loader), loss_test/test_idx, test_accuracy))

        print('Saving models and optimizer...')
        save_model_optimizer(model, optimizer, epoch, global_batch_counter_train, global_batch_counter_test, dir)
        print('Saved!')

        if bootstrapping_start_epoch and topk_pre_class*num_labels <= len(test_set) * bootstrapping_max_usage:
            print('Bootstrapping...')
            extra_train_set = get_topk(logits, topk_pre_class, y_pred, test_set)
            train_set = ConcatDataset([raw_train_set, extra_train_set])
            train_loader = DataLoader(train_set, sampler=RandomSampler(train_set), batch_size=batch_size_train)
            topk_pre_class = int(len(train_set)//num_labels * bootstrapping_increase_coef)

        loss_train = 0
        model.train()

    elif not save_freq_n_epoch:
        print('Saving models and optimizer...')
        save_model_optimizer(model, optimizer, epoch, global_batch_counter_train, global_batch_counter_test, dir)
        print('Saved!')

    if DEBUG:
        break

writer.close()
Example #12
def train_cl(model,
             train_datasets,
             replay_mode="none",
             scenario="task",
             rnt=None,
             classes_per_task=None,
             iters=2000,
             batch_size=32,
             batch_size_replay=None,
             loss_cbs=list(),
             eval_cbs=list(),
             sample_cbs=list(),
             generator=None,
             gen_iters=0,
             gen_loss_cbs=list(),
             feedback=False,
             reinit=False,
             args=None,
             only_last=False):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "generative", "current", "offline" and "none"
    [scenario]          <str>, choice from "task", "domain", "class" and "all"
    [classes_per_task]  <int>, # classes per task; only 1st task has [classes_per_task]*[first_task_class_boost] classes
    [rnt]               <float>, indicating relative importance of new task (if None, relative to # old tasks)
    [iters]             <int>, # optimization-steps (=batches) per task; 1st task has [first_task_iter_boost] steps more
    [batch_size_replay] <int>, number of samples to replay per batch
    [generator]         None or <nn.Module>, if a separate generative model should be trained (for [gen_iters] per task)
    [feedback]          <bool>, if True and [replay_mode]="generative", the main model is used for generating replay
    [only_last]         <bool>, only train on final task / episode
    [*_cbs]             <list> of call-back functions to evaluate training-progress'''

    # Should convolutional layers be frozen?
    freeze_convE = (utils.checkattr(args, "freeze_convE")
                    and hasattr(args, "depth") and args.depth > 0)

    # Use cuda?
    device = model._device()
    cuda = model._is_on_cuda()

    # Set default-values if not specified
    batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay

    # Initiate indicators for replay (no replay for 1st task)
    Generative = Current = Offline_TaskIL = False
    previous_model = None

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and model.si_c > 0:
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n),
                                      p.detach().clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):

        # If offline replay-setting, create large database of all tasks so far
        if replay_mode == "offline" and (not scenario == "task"):
            train_dataset = ConcatDataset(train_datasets[:task])
        # -but if "offline"+"task": all tasks so far should be visited separately (i.e., separate data-loader per task)
        if replay_mode == "offline" and scenario == "task":
            Offline_TaskIL = True
            data_loader = [None] * task

        # Initialize # iters left on data-loader(s)
        iters_left = 1 if (not Offline_TaskIL) else [1] * task

        # Prepare <dicts> to store running importance estimates and parameter-values before update
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Find [active_classes] (=classes in current task)
        active_classes = None  #-> for "domain"- or "all"-scenarios, always all classes are active
        if scenario == "task":
            # -for "task"-scenario, create <list> with for all tasks so far a <list> with the active classes
            active_classes = [
                list(range(classes_per_task * i, classes_per_task * (i + 1)))
                for i in range(task)
            ]
        elif scenario == "class":
            # -for "class"-scenario, create one <list> with active classes of all tasks so far
            active_classes = list(range(classes_per_task * task))

        # Reinitialize the model's parameters (if requested)
        if reinit:
            from define_models import init_params
            init_params(model, args)
            if generator is not None:
                init_params(generator, args)

        # Define a tqdm progress bar(s)
        iters_main = iters
        progress = tqdm.tqdm(range(1, iters_main + 1))
        if generator is not None:
            iters_gen = gen_iters
            progress_gen = tqdm.tqdm(range(1, iters_gen + 1))

        # Loop over all iterations
        iters_to_use = (iters_main if
                        (generator is None) else max(iters_main, iters_gen))
        # -if only the final task should be trained on:
        if only_last and not task == len(train_datasets):
            iters_to_use = 0
        for batch_index in range(1, iters_to_use + 1):

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            if not Offline_TaskIL:
                iters_left -= 1
                if iters_left == 0:
                    data_loader = iter(
                        utils.get_data_loader(train_dataset,
                                              batch_size,
                                              cuda=cuda,
                                              drop_last=True))
                    iters_left = len(data_loader)
            else:
                # -with "offline replay" in Task-IL scenario, there is a separate data-loader for each task
                batch_size_to_use = int(np.ceil(batch_size / task))
                for task_id in range(task):
                    iters_left[task_id] -= 1
                    if iters_left[task_id] == 0:
                        data_loader[task_id] = iter(
                            utils.get_data_loader(train_datasets[task_id],
                                                  batch_size_to_use,
                                                  cuda=cuda,
                                                  drop_last=True))
                        iters_left[task_id] = len(data_loader[task_id])

            #-----------------Collect data------------------#

            #####-----CURRENT BATCH-----#####
            if not Offline_TaskIL:
                x, y = next(
                    data_loader)  #--> sample training data of current task
                y = y - classes_per_task * (
                    task - 1
                ) if scenario == "task" else y  #--> ITL: adjust y-targets to 'active range'
                x, y = x.to(device), y.to(
                    device)  #--> transfer them to correct device
                #y = y.expand(1) if len(y.size())==1 else y                 #--> hack for if batch-size is 1
            else:
                x = y = task_used = None  #--> all tasks are "treated as replay"
                # -sample training data for all tasks so far, move to correct device and store in lists
                x_, y_ = list(), list()
                for task_id in range(task):
                    x_temp, y_temp = next(data_loader[task_id])
                    x_.append(x_temp.to(device))
                    y_temp = y_temp - (
                        classes_per_task * task_id
                    )  #--> adjust y-targets to 'active range'
                    if batch_size_to_use == 1:
                        y_temp = torch.tensor([
                            y_temp
                        ])  #--> correct dimensions if batch-size is 1
                    y_.append(y_temp.to(device))

            #####-----REPLAYED BATCH-----#####
            if not Offline_TaskIL and not Generative and not Current:
                x_ = y_ = scores_ = task_used = None  #-> if no replay

            #--------------------------------------------INPUTS----------------------------------------------------#

            ##-->> Current Replay <<--##
            if Current:
                x_ = x[:batch_size_replay]  #--> use current task inputs
                task_used = None

            ##-->> Generative Replay <<--##
            if Generative:
                #---> Only with generative replay, the resulting [x_] will be at the "hidden"-level
                conditional_gen = True if (
                    (previous_generator.per_class
                     and previous_generator.prior == "GMM") or utils.checkattr(
                         previous_generator, 'dg_gates')) else False

                # Sample [x_]
                if conditional_gen and scenario == "task":
                    # -if a conditional generator is used with task-IL scenario, generate data per previous task
                    x_ = list()
                    task_used = list()
                    for task_id in range(task - 1):
                        allowed_classes = list(
                            range(classes_per_task * task_id,
                                  classes_per_task * (task_id + 1)))
                        batch_size_replay_to_use = int(
                            np.ceil(batch_size_replay / (task - 1)))
                        x_temp_ = previous_generator.sample(
                            batch_size_replay_to_use,
                            allowed_classes=allowed_classes,
                            only_x=False)
                        x_.append(x_temp_[0])
                        task_used.append(x_temp_[2])
                else:
                    # -which classes are allowed to be generated? (relevant if conditional generator / decoder-gates)
                    allowed_classes = None if scenario == "domain" else list(
                        range(classes_per_task * (task - 1)))
                    # -which tasks/domains are allowed to be generated? (only relevant if "Domain-IL" with task-gates)
                    allowed_domains = list(range(task))
                    # -generate inputs representative of previous tasks
                    x_temp_ = previous_generator.sample(
                        batch_size_replay,
                        allowed_classes=allowed_classes,
                        allowed_domains=allowed_domains,
                        only_x=False,
                    )
                    x_ = x_temp_[0]
                    task_used = x_temp_[2]

            #--------------------------------------------OUTPUTS----------------------------------------------------#

            if Generative or Current:
                # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad()
                if scenario in ("domain",
                                "class") and previous_model.mask_dict is None:
                    # -if replay does not need to be evaluated for each task (i.e., not Task-IL and no task-specific mask)
                    with torch.no_grad():
                        all_scores_ = previous_model.classify(
                            x_, not_hidden=False if Generative else True)
                    scores_ = all_scores_[:, :(
                        classes_per_task * (task - 1)
                    )] if (
                        scenario == "class"
                    ) else all_scores_  # -> when scenario=="class", zero probs will be added in [loss_fn_kd]-function
                    # -also get the 'hard target'
                    _, y_ = torch.max(scores_, dim=1)
                else:
                    # -[x_] needs to be evaluated according to each previous task, so make list with entry per task
                    scores_ = list()
                    y_ = list()
                    # -if no task-mask and no conditional generator, all scores can be calculated in one go
                    if previous_model.mask_dict is None and not type(
                            x_) == list:
                        with torch.no_grad():
                            all_scores_ = previous_model.classify(
                                x_, not_hidden=False if Generative else True)
                    for task_id in range(task - 1):
                        # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately
                        if previous_model.mask_dict is not None:
                            previous_model.apply_XdGmask(task=task_id + 1)
                        if previous_model.mask_dict is not None or type(
                                x_) == list:
                            with torch.no_grad():
                                all_scores_ = previous_model.classify(
                                    x_[task_id] if type(x_) == list else x_,
                                    not_hidden=False if Generative else True)
                        if scenario == "domain":
                            # NOTE: if scenario=domain with task-mask, it's of course actually the Task-IL scenario!
                            #       this can be used as trick to run the Task-IL scenario with singlehead output layer
                            temp_scores_ = all_scores_
                        else:
                            temp_scores_ = all_scores_[:, (
                                classes_per_task * task_id):(classes_per_task *
                                                             (task_id + 1))]
                        scores_.append(temp_scores_)
                        # - also get hard target
                        _, temp_y_ = torch.max(temp_scores_, dim=1)
                        y_.append(temp_y_)
            # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done)
            y_ = y_ if (model.replay_targets == "hard") else None
            scores_ = scores_ if (model.replay_targets == "soft") else None

            #-----------------Train model(s)------------------#

            #---> Train MAIN MODEL
            if batch_index <= iters_main:

                # Train the main model with this batch
                loss_dict = model.train_a_batch(
                    x,
                    y=y,
                    x_=x_,
                    y_=y_,
                    scores_=scores_,
                    tasks_=task_used,
                    active_classes=active_classes,
                    task=task,
                    rnt=(1. if task == 1 else 1. /
                         task) if rnt is None else rnt,
                    freeze_convE=freeze_convE,
                    replay_not_hidden=False if Generative else True)

                # Update running parameter importance estimates in W
                if isinstance(model, ContinualLearner) and model.si_c > 0:
                    for n, p in model.convE.named_parameters():
                        if p.requires_grad:
                            n = "convE." + n
                            n = n.replace('.', '__')
                            if p.grad is not None:
                                W[n].add_(-p.grad * (p.detach() - p_old[n]))
                            p_old[n] = p.detach().clone()
                    for n, p in model.fcE.named_parameters():
                        if p.requires_grad:
                            n = "fcE." + n
                            n = n.replace('.', '__')
                            if p.grad is not None:
                                W[n].add_(-p.grad * (p.detach() - p_old[n]))
                            p_old[n] = p.detach().clone()
                    for n, p in model.classifier.named_parameters():
                        if p.requires_grad:
                            n = "classifier." + n
                            n = n.replace('.', '__')
                            if p.grad is not None:
                                W[n].add_(-p.grad * (p.detach() - p_old[n]))
                            p_old[n] = p.detach().clone()

                # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
                for loss_cb in loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress, batch_index, loss_dict, task=task)
                for eval_cb in eval_cbs:
                    if eval_cb is not None:
                        eval_cb(model, batch_index, task=task)
                if model.label == "VAE":
                    for sample_cb in sample_cbs:
                        if sample_cb is not None:
                            sample_cb(model,
                                      batch_index,
                                      task=task,
                                      allowed_classes=None if
                                      (scenario == "domain") else list(
                                          range(classes_per_task * task)))

            #---> Train GENERATOR
            if generator is not None and batch_index <= iters_gen:

                loss_dict = generator.train_a_batch(
                    x,
                    y=y,
                    x_=x_,
                    y_=y_,
                    scores_=scores_,
                    tasks_=task_used,
                    active_classes=active_classes,
                    rnt=(1. if task == 1 else 1. /
                         task) if rnt is None else rnt,
                    task=task,
                    freeze_convE=freeze_convE,
                    replay_not_hidden=False if Generative else True)

                # Fire callbacks on each iteration
                for loss_cb in gen_loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress_gen,
                                batch_index,
                                loss_dict,
                                task=task)
                for sample_cb in sample_cbs:
                    if sample_cb is not None:
                        sample_cb(generator,
                                  batch_index,
                                  task=task,
                                  allowed_classes=None if
                                  (scenario == "domain") else list(
                                      range(classes_per_task * task)))

        # Close progress-bar(s)
        progress.close()
        if generator is not None:
            progress_gen.close()

        ##----------> UPON FINISHING EACH TASK...

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and model.ewc_lambda > 0:
            # -find allowed classes
            allowed_classes = list(
                range(classes_per_task * (task - 1), classes_per_task *
                      task)) if scenario == "task" else (
                          list(range(classes_per_task *
                                     task)) if scenario == "class" else None)
            # -if needed, apply correct task-specific mask
            if model.mask_dict is not None:
                model.apply_XdGmask(task=task)
            # -estimate FI-matrix
            model.estimate_fisher(train_dataset,
                                  allowed_classes=allowed_classes)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            model.update_omega(W, model.epsilon)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model).eval()
        if replay_mode == "generative":
            Generative = True
            previous_generator = previous_model if feedback else copy.deepcopy(
                generator).eval()
        elif replay_mode == 'current':
            Current = True
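
The three per-module loops above (convE, fcE, classifier) all perform the same Synaptic Intelligence bookkeeping: accumulate -grad * (new value - old value) into a running path integral W, which is later normalized by update_omega. A minimal, self-contained sketch of that idea on a toy model (the names W, p_old, p_start, epsilon and omega below are illustrative, not taken from this repository):

import torch
import torch.nn as nn

# Toy model standing in for the classifier trained above
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Per-parameter running path integral W, value before the latest update (p_old),
# and value at the start of the task (p_start)
W = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
p_old = {n: p.detach().clone() for n, p in model.named_parameters()}
p_start = {n: p.detach().clone() for n, p in model.named_parameters()}

for _ in range(20):  # a few toy optimization steps on one "task"
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    opt.zero_grad()
    nn.functional.cross_entropy(model(x), y).backward()
    opt.step()
    # Same update as the loops above: accumulate -grad * (new value - old value)
    for n, p in model.named_parameters():
        if p.grad is not None:
            W[n].add_(-p.grad * (p.detach() - p_old[n]))
        p_old[n] = p.detach().clone()

# Upon finishing the task (the analogue of model.update_omega(W, model.epsilon)):
epsilon = 0.1  # damping constant that avoids division by zero
omega = {n: W[n] / ((p.detach() - p_start[n]) ** 2 + epsilon)
         for n, p in model.named_parameters()}
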
Example #13
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, "log"))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, "ckpt"))
        add_log_to_file(join(opts.output_dir, "log", "log.txt"))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, "results_val"))
        os.makedirs(join(opts.output_dir, "results_test"))
        os.makedirs(join(opts.output_dir, "results_train"))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(
        opts.train_img_dbs), "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(
            ItmRankDataset(txt_db, img_db, opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}"
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrieval.from_pretrained(opts.model_config,
                                                        state_dict=checkpoint,
                                                        img_dim=IMG_DIM,
                                                        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level="O2")

    global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter("loss")
    model.train()

    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        train_dataloader = build_dataloader(train_dataset, itm_rank_collate,
                                            True, opts)
        for step, batch in enumerate(train_dataloader):
            n_examples += batch["input_ids"].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr_this_step
                TB_LOGGER.add_scalar("lr", lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar("loss", running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar("grad_norm", grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f"------------Step {global_step}-------------")
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f"{tot_ex} examples trained at "
                                f"{ex_per_sec} ex/s")
                    TB_LOGGER.add_scalar("perf/ex_per_s", ex_per_sec,
                                         global_step)
                    LOGGER.info(f"-------------------------------------------")

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")

    pbar.close()
    if opts.num_train_steps % opts.valid_steps != 0:
        # final validation
        val_log = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)

    # evaluation
    for split, loader in [("val", eval_loader_val),
                          ("test", eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
Example #14
0
def train_cl(model, train_datasets, replay_mode="none", scenario="class",classes_per_task=None,iters=2000,batch_size=32,
             generator=None, gen_iters=0, gen_loss_cbs=list(), loss_cbs=list(), eval_cbs=list(), sample_cbs=list(),
             use_exemplars=True, add_exemplars=False, metric_cbs=list(), buffer_size=1000, valid_datasets=None, early_stop=False, validation=False):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "generative", "exact", "current", "offline" and "none"
    [scenario]          <str>, choice from "task", "domain" and "class"
    [classes_per_task]  <int>, # of classes per task
    [iters]             <int>, # of optimization-steps (i.e., # of batches) per task
    [generator]         None or <nn.Module>, if a separate generative model should be trained (for [gen_iters] per task)
    [*_cbs]             <list> of call-back functions to evaluate training-progress'''

    peak_ramu = ramu.compute("TRAINING")
    valid_precs = []
    train_precs = []
    # Set model in training-mode
    model.train()

    # Use cuda?
    cuda = model._is_on_cuda()
    device = model._device()

    # Initiate possible sources for replay (no replay for 1st task)
    Exact = Generative = Current = False
    previous_model = None

    if replay_mode == "naive-rehearsal": 
        replay_buffer = ReplayBuffer(size=buffer_size, scenario=scenario)

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and (model.si_c>0):
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n), p.data.clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):
        prev_prec = 0.0
        peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
        
        # If offline replay-setting, create large database of all tasks so far
        if replay_mode=="offline" and (not scenario=="task"):
            train_dataset = ConcatDataset(train_datasets[:task])
        # -but if "offline"+"task"-scenario: all tasks so far included in 'exact replay' & no current batch
        if replay_mode=="offline" and scenario == "task":
            Exact = True
            previous_datasets = train_datasets

        # Add exemplars (if available) to current dataset (if requested)
        if add_exemplars and task>1:
            target_transform = (lambda y, x=classes_per_task: y%x) if scenario=="domain" else None
            exemplar_dataset = ExemplarDataset(model.exemplar_sets, target_transform=target_transform)
            training_dataset = ConcatDataset([train_dataset, exemplar_dataset])
        else:
            training_dataset = train_dataset
        
        # Prepare <dicts> to store running importance estimates and param-values before update ("Synaptic Intelligence")
        if isinstance(model, ContinualLearner) and (model.si_c>0):
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Find [active_classes]
        active_classes = None  # -> for Domain-IL scenario, always all classes are active
        if scenario == "task":
            # -for Task-IL scenario, create <list> with for all tasks so far a <list> with the active classes
            active_classes = [list(range(classes_per_task * i, classes_per_task * (i + 1))) for i in range(task)]
        elif scenario == "class":
            # -for Class-IL scenario, create one <list> with active classes of all tasks so far
            active_classes = list(range(classes_per_task * task))

        # Reset state of optimizer(s) for every task (if requested)
        if model.optim_type=="adam_reset":
            model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))
        if (generator is not None) and generator.optim_type=="adam_reset":
            generator.optimizer = optim.Adam(generator.optim_list, betas=(0.9, 0.999))

        # Initialize # iters left on current data-loader(s)
        iters_left = iters_left_previous = 1
        if scenario=="task":
            up_to_task = task if replay_mode=="offline" else task-1
            iters_left_previous = [1]*up_to_task
            data_loader_previous = [None]*up_to_task

        # Define tqdm progress bar(s)
        progress = tqdm.tqdm(range(1, iters+1))
        if generator is not None:
            progress_gen = tqdm.tqdm(range(1, gen_iters+1))

        # Loop over all iterations
        iters_to_use = iters if (generator is None) else max(iters, gen_iters)
        peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
        for batch_index in range(1, iters_to_use+1):

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            iters_left -= 1
            if iters_left==0:
                data_loader = iter(utils.get_data_loader(training_dataset, batch_size, cuda=cuda, drop_last=True))
                # NOTE:  [train_dataset]  is training-set of current task
                #      [training_dataset] is training-set of current task with stored exemplars added (if requested)
                iters_left = len(data_loader)
            if Exact:
                if scenario=="task":
                    up_to_task = task if replay_mode=="offline" else task-1
                    batch_size_replay = int(np.ceil(batch_size/up_to_task)) if (up_to_task>1) else batch_size
                    # -in Task-IL scenario, need separate replay for each task
                    for task_id in range(up_to_task):
                        batch_size_to_use = min(batch_size_replay, len(previous_datasets[task_id]))
                        iters_left_previous[task_id] -= 1
                        if iters_left_previous[task_id]==0:
                            data_loader_previous[task_id] = iter(utils.get_data_loader(
                                train_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True
                            ))
                            iters_left_previous[task_id] = len(data_loader_previous[task_id])
                else:
                    iters_left_previous -= 1
                    if iters_left_previous==0:
                        batch_size_to_use = min(batch_size, len(ConcatDataset(previous_datasets)))
                        data_loader_previous = iter(utils.get_data_loader(ConcatDataset(previous_datasets),
                                                                          batch_size_to_use, cuda=cuda, drop_last=True))
                        iters_left_previous = len(data_loader_previous)


            # -----------------Collect data------------------#

            peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
            #####-----CURRENT BATCH-----#####
            if replay_mode=="offline" and scenario=="task":
                x = y = scores = None
            else:
                x, y = next(data_loader)                                    #--> sample training data of current task
                y = y-classes_per_task*(task-1) if scenario=="task" else y  #--> ITL: adjust y-targets to 'active range'
                x, y = x.to(device), y.to(device)                           #--> transfer them to correct device
                # If --bce, --bce-distill & scenario=="class", calculate scores of current batch with previous model
                binary_distillation = hasattr(model, "binaryCE") and model.binaryCE and model.binaryCE_distill
                if binary_distillation and scenario=="class" and (previous_model is not None):
                    with torch.no_grad():
                        scores = previous_model(x)[:, :(classes_per_task * (task - 1))]
                else:
                    scores = None
            peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))

            #####-----REPLAYED BATCH-----#####
            if not Exact and not Generative and not Current:
                x_ = y_ = scores_ = None   #-> if no replay

            ##-->> Exact Replay <<--##
            if Exact:
                scores_ = None
                if scenario in ("domain", "class"):
                    # Sample replayed training data, move to correct device
                    x_, y_ = next(data_loader_previous)
                    x_ = x_.to(device)
                    y_ = y_.to(device) if (model.replay_targets=="hard") else None
                    # If required, get target scores (i.e., [scores_]) -- using previous model, with no_grad()
                    if (model.replay_targets=="soft"):
                        with torch.no_grad():
                            scores_ = previous_model(x_)
                        scores_ = scores_[:, :(classes_per_task*(task-1))] if scenario=="class" else scores_
                        #-> when scenario=="class", zero probabilities will be added in the [utils.loss_fn_kd]-function
                elif scenario=="task":
                    # Sample replayed training data, wrap in (cuda-)Variables and store in lists
                    x_ = list()
                    y_ = list()
                    up_to_task = task if replay_mode=="offline" else task-1
                    for task_id in range(up_to_task):
                        x_temp, y_temp = next(data_loader_previous[task_id])
                        x_.append(x_temp.to(device))
                        # -only keep [y_] if required (as otherwise unnecessary computations will be done)
                        if model.replay_targets=="hard":
                            y_temp = y_temp - (classes_per_task*task_id) #-> adjust y-targets to 'active range'
                            y_.append(y_temp.to(device))
                        else:
                            y_.append(None)
                    # If required, get target scores (i.e., [scores_]) -- using previous model
                    if (model.replay_targets=="soft") and (previous_model is not None):
                        scores_ = list()
                        for task_id in range(up_to_task):
                            with torch.no_grad():
                                scores_temp = previous_model(x_[task_id])
                            scores_temp = scores_temp[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))]
                            scores_.append(scores_temp)

            ##-->> Generative / Current Replay <<--##
            if Generative or Current:
                # Get replayed data (i.e., [x_]) -- either current data or use previous generator
                x_ = x if Current else previous_generator.sample(batch_size)
                peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
                # Get target scores and labels (i.e., [scores_] / [y_]) -- using previous model, with no_grad()
                # -if there are no task-specific mask, obtain all predicted scores at once
                if (not hasattr(previous_model, "mask_dict")) or (previous_model.mask_dict is None):
                    with torch.no_grad():
                        all_scores_ = previous_model(x_)
                # -depending on chosen scenario, collect relevant predicted scores (per task, if required)
                if scenario in ("domain", "class") and (
                        (not hasattr(previous_model, "mask_dict")) or (previous_model.mask_dict is None)
                ):
                    scores_ = all_scores_[:,:(classes_per_task * (task - 1))] if scenario == "class" else all_scores_
                    _, y_ = torch.max(scores_, dim=1)
                else:
                    # NOTE: it's possible to have scenario=domain with task-mask (so actually it's the Task-IL scenario)
                    # -[x_] needs to be evaluated according to each previous task, so make list with entry per task
                    scores_ = list()
                    y_ = list()
                    for task_id in range(task - 1):
                        # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately
                        if hasattr(previous_model, "mask_dict") and previous_model.mask_dict is not None:
                            previous_model.apply_XdGmask(task=task_id + 1)
                            with torch.no_grad():
                                all_scores_ = previous_model(x_)
                        if scenario=="domain":
                            temp_scores_ = all_scores_
                        else:
                            temp_scores_ = all_scores_[:,
                                           (classes_per_task * task_id):(classes_per_task * (task_id + 1))]
                        _, temp_y_ = torch.max(temp_scores_, dim=1)
                        scores_.append(temp_scores_)
                        y_.append(temp_y_)

                # Only keep predicted y/scores if required (as otherwise unnecessary computations will be done)
                y_ = y_ if (model.replay_targets == "hard") else None
                scores_ = scores_ if (model.replay_targets == "soft") else None
                peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))

            # Validation for early stopping 
            if early_stop and valid_datasets and (batch_index % 100 == 0): 
                prec = evaluate.validate(
                    model, valid_datasets[task-1], verbose=False, test_size=None, task=task, 
                    allowed_classes=list(range(classes_per_task*(task-1), classes_per_task*(task))) if scenario=="task" else list(range(task))
                ) 
                if prec < prev_prec: 
                    prev_prec = 0.0
                    break 
                prev_prec = prec 

            #---> Train MAIN MODEL
            if batch_index <= iters:
                peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
                if replay_mode == "naive-rehearsal": 
                    replayed_data = replay_buffer.replay(batch_size)
                    if replayed_data: 
                        x_, y_ = zip(*replayed_data)
                        x_, y_ = torch.stack(x_), torch.tensor(y_)
                        x_ = x_.to(device) 
                        y_ = y_.to(device) 
                        if scenario == "task": 
                            y_ = [y_]
                # Train the main model with this batch
                loss_dict = model.train_a_batch(x, y, x_=x_, y_=y_, scores=scores, scores_=scores_,
                                                active_classes=active_classes, task=task, rnt = 1./task)
                if replay_mode == "naive-rehearsal": 
                    replay_buffer.add(zip(x, y))
                peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))

                # Update running parameter importance estimates in W
                if isinstance(model, ContinualLearner) and (model.si_c>0):
                    for n, p in model.named_parameters():
                        if p.requires_grad:
                            n = n.replace('.', '__')
                            if p.grad is not None:
                                W[n].add_(-p.grad*(p.detach()-p_old[n]))
                            p_old[n] = p.detach().clone()

                # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
                for loss_cb in loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress, batch_index, loss_dict, task=task)
                for eval_cb in eval_cbs:
                    if eval_cb is not None:
                        eval_cb(model, batch_index, task=task)
                if model.label == "VAE":
                    for sample_cb in sample_cbs:
                        if sample_cb is not None:
                            sample_cb(model, batch_index, task=task)


            #---> Train GENERATOR
            if generator is not None and batch_index <= gen_iters:

                # Train the generator with this batch
                peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
                loss_dict = generator.train_a_batch(x, y, x_=x_, y_=y_, scores_=scores_, active_classes=active_classes,
                                                    task=task, rnt=1./task)
                peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))

                # Fire callbacks on each iteration
                for loss_cb in gen_loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress_gen, batch_index, loss_dict, task=task)
                for sample_cb in sample_cbs:
                    if sample_cb is not None:
                        sample_cb(generator, batch_index, task=task)

        if validation and valid_datasets: 
            v_precs = [evaluate.validate(
                model, valid_datasets[i-1], verbose=False, test_size=None, task=i, 
                allowed_classes=list(range(classes_per_task*(i-1), classes_per_task*(i))) if scenario=="task" else list(range(task))
            ) for i in range(1, task+1)]
            t_precs = [evaluate.validate(
                model, train_datasets[i-1], verbose=False, test_size=None, task=i, 
                allowed_classes=list(range(classes_per_task*(i-1), classes_per_task*(i))) if scenario=="task" else list(range(task))
            ) for i in range(1, task+1)]
            valid_precs.append((task, batch_index, v_precs))
            train_precs.append((task, batch_index, t_precs))

        ##----------> UPON FINISHING EACH TASK...
        if replay_mode == "naive-rehearsal": 
            replay_buffer.update()

        # Close progress-bar(s)
        progress.close()
        if generator is not None:
            progress_gen.close()

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and (model.ewc_lambda>0):
            # -find allowed classes
            allowed_classes = list(
                range(classes_per_task*(task-1), classes_per_task*task)
            ) if scenario=="task" else (list(range(classes_per_task*task)) if scenario=="class" else None)
            # -if needed, apply correct task-specific mask
            if model.mask_dict is not None:
                model.apply_XdGmask(task=task)
            # -estimate FI-matrix
            model.estimate_fisher(training_dataset, allowed_classes=allowed_classes)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and (model.si_c>0):
            model.update_omega(W, model.epsilon)

        # EXEMPLARS: update exemplar sets
        if (add_exemplars or use_exemplars) or replay_mode=="exemplars":
            exemplars_per_class = int(np.floor(model.memory_budget / (classes_per_task*task)))
            # reduce exemplar-sets
            model.reduce_exemplar_sets(exemplars_per_class)
            # for each new class trained on, construct exemplar-set
            new_classes = list(range(classes_per_task)) if scenario=="domain" else list(range(classes_per_task*(task-1),
                                                                                              classes_per_task*task))
            for class_id in new_classes:
                # create new dataset containing only all examples of this class
                class_dataset = SubDataset(original_dataset=train_dataset, sub_labels=[class_id])
                # based on this dataset, construct new exemplar-set for this class
                model.construct_exemplar_set(dataset=class_dataset, n=exemplars_per_class)
            model.compute_means = True

        # Calculate statistics required for metrics
        for metric_cb in metric_cbs:
            if metric_cb is not None:
                metric_cb(model, iters, task=task)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model).eval()
        if replay_mode == 'generative':
            Generative = True
            previous_generator = copy.deepcopy(generator).eval() if generator is not None else previous_model
        elif replay_mode == 'current':
            Current = True
        elif replay_mode in ('exemplars', 'exact'):
            Exact = True
            if replay_mode == "exact":
                previous_datasets = train_datasets[:task]
            else:
                if scenario == "task":
                    previous_datasets = []
                    for task_id in range(task):
                        previous_datasets.append(
                            ExemplarDataset(
                                model.exemplar_sets[
                                (classes_per_task * task_id):(classes_per_task * (task_id + 1))],
                                target_transform=lambda y, x=classes_per_task * task_id: y + x)
                        )
                else:
                    target_transform = (lambda y, x=classes_per_task: y % x) if scenario == "domain" else None
                    previous_datasets = [
                        ExemplarDataset(model.exemplar_sets, target_transform=target_transform)]
        peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
    peak_ramu = max(peak_ramu, ramu.compute("TRAINING"))
    print("PEAK TRAINING RAM:", peak_ramu)
    if validation: 
        return (valid_precs, train_precs)
    return None
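
The ReplayBuffer used for the "naive-rehearsal" mode is not shown in this snippet; a hypothetical minimal buffer with the same add / replay / update interface might look like the sketch below (class name and sampling strategy are assumptions, not the original implementation):

import random
import torch

class SimpleReplayBuffer:
    """Minimal rehearsal buffer with the add / replay / update interface used above."""
    def __init__(self, size=1000):
        self.size = size
        self.stored = []    # examples from previously finished tasks
        self.pending = []   # examples seen during the current task

    def add(self, batch):
        # batch is an iterable of (x, y) pairs, e.g. zip(x, y)
        self.pending.extend((x.detach().cpu(), int(y)) for x, y in batch)

    def replay(self, batch_size):
        if not self.stored:
            return []       # nothing to rehearse before the first task is finished
        k = min(batch_size, len(self.stored))
        return random.sample(self.stored, k)

    def update(self):
        # After each task: merge new examples and subsample down to the memory budget
        self.stored.extend(self.pending)
        self.pending = []
        if len(self.stored) > self.size:
            self.stored = random.sample(self.stored, self.size)

# Usage mirroring the training loop above
buf = SimpleReplayBuffer(size=100)
x, y = torch.randn(32, 3, 32, 32), torch.randint(0, 10, (32,))
buf.add(zip(x, y))
buf.update()
replayed = buf.replay(8)
if replayed:
    x_, y_ = zip(*replayed)
    x_, y_ = torch.stack(x_), torch.tensor(y_)
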
Example #15
0
            store_labels(label_file, dataset.class_names)
            num_classes = len(dataset.class_names)
        elif args.dataset_type == 'open_images':
            dataset = OpenImagesDataset(dataset_path,
                 transform=train_transform, target_transform=target_transform,
                 dataset_type="train", balance_data=args.balance_data)
            label_file = os.path.join(args.checkpoint_folder, "open-images-model-labels.txt")
            store_labels(label_file, dataset.class_names)
            logging.info(dataset)
            num_classes = len(dataset.class_names)
        else:
            raise ValueError(f"Dataset tpye {args.dataset_type} is not supported.")
        datasets.append(dataset)

    logging.info(f"Stored labels into file {label_file}.")
    train_dataset = ConcatDataset(datasets)
    logging.info("Train dataset size: {}".format(len(train_dataset)))
    train_loader = DataLoader(train_dataset, args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True)
    logging.info("Prepare Validation datasets.")
    if args.dataset_type == "wildlife":
        val_dataset = WildlifeDataset(args.validation_dataset, transform=test_transform,
                                 target_transform=target_transform, is_test=True)
    elif args.dataset_type == 'open_images':
        val_dataset = OpenImagesDataset(dataset_path,
                                        transform=test_transform, target_transform=target_transform,
                                        dataset_type="test")
        logging.info(val_dataset)
    logging.info("validation dataset size: {}".format(len(val_dataset)))
Example #16
0
                    plt.show()


if __name__ == "__main__":
    from torchvision.datasets import MNIST, USPS, FashionMNIST, CIFAR10
    from torchtext.datasets import AG_NEWS

    n = None
    # semisupervised_proportion = .2

    e = DEN(n_components=2, internal_dim=128)

    USPS_data_train = USPS("./", train=True, download=True)
    USPS_data_test = USPS("./", train=False, download=True)
    USPS_data = ConcatDataset([USPS_data_test, USPS_data_train])
    X, y = zip(*USPS_data)

    y_numpy = np.array(y[:n])
    X_numpy = np.array(
        [np.asarray(X[i]) for i in range(n if n is not None else len(X))])
    X = torch.Tensor(X_numpy).unsqueeze(1)

    # which = np.random.choice(len(y_numpy), int((1-semisupervised_proportion)*len(y_numpy)), replace = False)
    # y_for_verification = copy.deepcopy(y_numpy)
    # y_numpy[which] = -1

    # news_train, news_test = AG_NEWS('./', ngrams = 1)
    # X, y = zip(*([item[1], item[0]] for item in news_test))
    # X = X[:n]
    # y = y[:n]
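
A note on the zip(*dataset) trick above: it iterates the whole concatenated dataset and materializes it in memory, which is acceptable for USPS-sized data but not for large datasets. A small self-contained illustration with synthetic data (shapes and names are illustrative):

import numpy as np
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two small in-memory datasets standing in for the USPS train/test splits above
part_a = TensorDataset(torch.rand(60, 16, 16), torch.randint(0, 10, (60,)))
part_b = TensorDataset(torch.rand(40, 16, 16), torch.randint(0, 10, (40,)))
full = ConcatDataset([part_a, part_b])

# zip(*dataset) walks the concatenated dataset item by item and keeps everything in memory
X, y = zip(*full)
X = torch.stack(X).unsqueeze(1)                 # (100, 1, 16, 16) image tensor
y = np.array([int(label) for label in y])       # label vector of length 100
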
Example #17
0
def main():
    if not os.path.exists(opt.model_save_file):
        os.makedirs(opt.model_save_file)
    vocab = Vocab(opt.emb_filename)
    log.info('Loading {} Datasets...'.format(opt.dataset))
    log.info('Domains: {}'.format(opt.domains))

    train_sets, dev_sets, test_sets, unlabeled_sets = {}, {}, {}, {}
    for domain in opt.all_domains:
        train_sets[domain], dev_sets[domain], test_sets[domain], unlabeled_sets[domain] = \
            get_fdu_mtl_datasets(vocab, opt.amazon_lang_dir, domain, opt.max_seq_len)
    opt.num_labels = FduMtlDataset.num_labels
    log.info('Done Loading {} Datasets.'.format(opt.dataset))
    train_sampler, test_sampler, dev_sampler = {}, {}, {}
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        # train_loaders[domain] = DataLoader(train_sets[domain],opt.batch_size, shuffle=True, collate_fn = my_collate)
        train_sampler[domain] = RandomSampler(train_sets[domain])
        train_loaders[domain] = DataLoader(train_sets[domain], sampler=train_sampler[domain], batch_size=opt.batch_size)
        train_iters[domain] = iter(train_loaders[domain])

    for domain in opt.dev_domains:
        test_sampler[domain] = RandomSampler(test_sets[domain])
        test_loaders[domain] = DataLoader(test_sets[domain], sampler=test_sampler[domain], batch_size=opt.batch_size)
        dev_sampler[domain] = RandomSampler(dev_sets[domain])
        dev_loaders[domain] = DataLoader(dev_sets[domain], sampler=dev_sampler[domain], batch_size=opt.batch_size)
        # dev_loaders[domain] = DataLoader(dev_sets[domain],opt.batch_size, shuffle=False, collate_fn = my_collate)
        # test_loaders[domain] = DataLoader(test_sets[domain],opt.batch_size, shuffle=False, collate_fn = my_collate)

    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset([train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception('Unknown options for the unlabeled data usage: {}'.format(opt.unlabeled_data))
        # unlabeled_loaders[domain] = DataLoader(uset,opt.batch_size, shuffle=True, collate_fn = my_collate)
        uset_sampler = RandomSampler(uset)
        unlabeled_loaders[domain] = DataLoader(uset, sampler=uset_sampler, batch_size=opt.batch_size)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    if opt.shared:
        log.info('Starting training shared_nobert')
        cv = train_shared_nobert(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders)
        log.info('Training done...')
        acc = sum(cv['valid'].values()) / len(cv['valid'])
        log.info('Validation Set Domain Average\t{}'.format(acc))
        test_acc = sum(cv['test'].values()) / len(cv['test'])
        log.info('Test Set Domain Average\t{}'.format(test_acc))

    if opt.shared_man:
        log.info('Starting training shared_man_nobert')
        cv = train_shared_man_nobert(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders)
        log.info('Training done...')
        acc = sum(cv['valid'].values()) / len(cv['valid'])
        log.info('Validation Set Domain Average\t{}'.format(acc))
        test_acc = sum(cv['test'].values()) / len(cv['test'])
        log.info('Test Set Domain Average\t{}'.format(test_acc))

    if opt.private:
        log.info('Starting training private_nobert')
        cv = train_private_nobert(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders)
        log.info('Training done...')
        acc = sum(cv['valid'].values()) / len(cv['valid'])
        log.info('Validation Set Domain Average\t{}'.format(acc))
        test_acc = sum(cv['test'].values()) / len(cv['test'])
        log.info('Test Set Domain Average\t{}'.format(test_acc))

    if opt.shared_private_man:
        log.info('Starting training shared_private_man_nobert')
        cv = train_nobert(vocab, train_loaders, unlabeled_loaders, train_iters, unlabeled_iters, dev_loaders, test_loaders)
        log.info('Training done...')
        acc = sum(cv['valid'].values()) / len(cv['valid'])
        log.info('Validation Set Domain Average\t{}'.format(acc))
        test_acc = sum(cv['test'].values()) / len(cv['test'])
        log.info('Test Set Domain Average\t{}'.format(test_acc))
    
    return cv
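
The per-domain unlabeled_iters built above are consumed elsewhere with next(); a common companion pattern is to restart an exhausted iterator on StopIteration. A hedged sketch of the "both" option (labeled + unlabeled data concatenated) together with that restart pattern, using toy data and a hypothetical next_unlabeled_batch helper that is not part of this codebase:

import torch
from torch.utils.data import ConcatDataset, DataLoader, RandomSampler, TensorDataset

# Toy labeled / unlabeled sets for one domain, standing in for train_sets / unlabeled_sets
train_set = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
unlabeled_set = TensorDataset(torch.randn(128, 8), torch.zeros(128, dtype=torch.long))

# The "both" option above: draw unlabeled batches from labeled and unlabeled data together
uset = ConcatDataset([train_set, unlabeled_set])
loader = DataLoader(uset, sampler=RandomSampler(uset), batch_size=16)
it = iter(loader)

def next_unlabeled_batch():
    """Draw one batch, restarting the iterator once the loader is exhausted."""
    global it
    try:
        return next(it)
    except StopIteration:
        it = iter(loader)
        return next(it)

for step in range(20):  # more steps than batches per epoch; the iterator restarts silently
    x_u, _ = next_unlabeled_batch()
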
Example #18
0
def run(args, verbose=False):

    # Create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only want param-stamp, get it and exit
    if args.get_stamp:
        from param_stamp import get_param_stamp_from_args
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Report whether cuda is used
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets,
     test_datasets), config, classes_per_task = get_multitask_experiment(
         name=args.experiment,
         scenario=args.scenario,
         tasks=args.tasks,
         data_dir=args.d_dir,
         normalize=True if utils.checkattr(args, "normalize") else False,
         augment=True if utils.checkattr(args, "augment") else False,
         verbose=verbose,
         exception=True if args.seed < 10 else False,
         only_test=(not args.train))

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- MAIN MODEL -----#
    #----------------------#

    # Define main model (i.e., classifier, if requested with feedback connections)
    if verbose and (utils.checkattr(args, "pre_convE") or utils.checkattr(args, "pre_convD")) and \
            (hasattr(args, "depth") and args.depth>0):
        print("\nDefining the model...")
    if utils.checkattr(args, 'feedback'):
        model = define.define_autoencoder(args=args,
                                          config=config,
                                          device=device)
    else:
        model = define.define_classifier(args=args,
                                         config=config,
                                         device=device)

    # Initialize / use pre-trained / freeze model-parameters
    # - initialize (pre-trained) parameters
    model = define.init_params(model, args)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "freeze_convE"):
        for param in model.convE.parameters():
            param.requires_grad = False
    if utils.checkattr(args, 'feedback') and utils.checkattr(
            args, "freeze_convD"):
        for param in model.convD.parameters():
            param.requires_grad = False

    # Define optimizer (only optimize parameters that "requires_grad")
    model.optim_list = [
        {
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': args.lr
        },
    ]
    model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------------------------#
    #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----#
    #----------------------------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        model.fisher_n = args.fisher_n
        model.online = utils.checkattr(args, 'online')
        if model.online:
            model.gamma = args.gamma

    # Synaptic Intelligence (SI)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'):
        model.si_c = args.si_c if args.si else 0
        model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and utils.checkattr(
            args, 'xdg') and args.xdg_prop > 0:
        model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks)

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, ContinualLearner) and hasattr(
            args, 'replay') and not args.replay == "none":
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp

    # If needed, specify separate model for the generator
    train_gen = (hasattr(args, 'replay') and args.replay == "generative"
                 and not utils.checkattr(args, 'feedback'))
    if train_gen:
        # Specify architecture
        generator = define.define_autoencoder(args,
                                              config,
                                              device,
                                              generator=True)

        # Initialize parameters
        generator = define.init_params(generator, args)
        # -freeze weights of conv-layers?
        if utils.checkattr(args, "freeze_convE"):
            for param in generator.convE.parameters():
                param.requires_grad = False
        if utils.checkattr(args, "freeze_convD"):
            for param in generator.convD.parameters():
                param.requires_grad = False

        # Set optimizer(s)
        generator.optim_list = [
            {
                'params': filter(lambda p: p.requires_grad,
                                 generator.parameters()),
                'lr': args.lr_gen if hasattr(args, 'lr_gen') else args.lr
            },
        ]
        generator.optimizer = optim.Adam(generator.optim_list,
                                         betas=(0.9, 0.999))
    else:
        generator = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp = get_param_stamp(
        args,
        model.name,
        verbose=verbose,
        replay=True if
        (hasattr(args, 'replay') and not args.replay == "none") else False,
        replay_model_name=generator.name if
        (hasattr(args, 'replay') and args.replay == "generative"
         and not utils.checkattr(args, 'feedback')) else None,
    )

    # Print some model-characteristics on the screen
    if verbose:
        # -main model
        utils.print_model_info(model, title="MAIN MODEL")
        # -generator
        if generator is not None:
            utils.print_model_info(generator, title="GENERATOR")

    # Define [precision_dict] to keep track of performance during training, for storing and for later plotting in pdf
    precision_dict = evaluate.initiate_precision_dict(args.tasks)

    # Prepare for plotting in visdom
    visdom = None
    if args.visdom:
        env_name = "{exp}{tasks}-{scenario}".format(exp=args.experiment,
                                                    tasks=args.tasks,
                                                    scenario=args.scenario)
        replay_statement = "{mode}{fb}{con}{gat}{int}{dis}{b}{u}".format(
            mode=args.replay,
            fb="Rtf" if utils.checkattr(args, "feedback") else "",
            con="Con" if (hasattr(args, "prior") and args.prior == "GMM"
                          and utils.checkattr(args, "per_class")) else "",
            gat="Gat{}".format(args.dg_prop) if
            (utils.checkattr(args, "dg_gates") and hasattr(args, "dg_prop")
             and args.dg_prop > 0) else "",
            int="Int" if utils.checkattr(args, "hidden") else "",
            dis="Dis" if args.replay == "generative" and args.distill else "",
            b="" if
            (args.batch_replay is None or args.batch_replay == args.batch) else
            "-br{}".format(args.batch_replay),
            u="" if args.g_fc_uni == args.fc_units else "-gu{}".format(
                args.g_fc_uni)) if (hasattr(args, "replay")
                                    and not args.replay == "none") else "NR"
        graph_name = "{replay}{syn}{ewc}{xdg}".format(
            replay=replay_statement,
            syn="-si{}".format(args.si_c)
            if utils.checkattr(args, 'si') else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if utils.checkattr(args, "online") else "") if utils.checkattr(
                    args, 'ewc') else "",
            xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0
            else "-XdG{}".format(args.xdg_prop),
        )
        visdom = {'env': env_name, 'graph': graph_name}

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    g_iters = args.g_iters if hasattr(args, 'g_iters') else args.iters

    # Callbacks for reporting on and visualizing loss
    generator_loss_cbs = [
        cb._VAE_loss_cb(
            log=args.loss_log,
            visdom=visdom,
            replay=(hasattr(args, "replay") and not args.replay == "none"),
            model=model if utils.checkattr(args, 'feedback') else generator,
            tasks=args.tasks,
            iters_per_task=args.iters
            if utils.checkattr(args, 'feedback') else g_iters)
    ] if (train_gen or utils.checkattr(args, 'feedback')) else [None]
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           iters_per_task=args.iters,
                           tasks=args.tasks,
                           replay=(hasattr(args, "replay")
                                   and not args.replay == "none"))
    ] if (not utils.checkattr(args, 'feedback')) else [None]

    # Callbacks for evaluating and plotting generated / reconstructed samples
    no_samples = (utils.checkattr(args, "no_samples")
                  or (utils.checkattr(args, "hidden")
                      and hasattr(args, 'depth') and args.depth > 0))
    sample_cbs = [
        cb._sample_cb(log=args.sample_log,
                      visdom=visdom,
                      config=config,
                      test_datasets=test_datasets,
                      sample_size=args.sample_n,
                      iters_per_task=g_iters)
    ] if ((train_gen or utils.checkattr(args, 'feedback'))
          and not no_samples) else [None]

    # Callbacks for reporting and visualizing accuracy, and visualizing representation extracted by main model
    # -visdom (i.e., after each [prec_log] iters)
    eval_cb = cb._eval_cb(
        log=args.prec_log,
        test_datasets=test_datasets,
        visdom=visdom,
        precision_dict=None,
        iters_per_task=args.iters,
        test_size=args.prec_n,
        classes_per_task=classes_per_task,
        scenario=args.scenario,
    )
    # -pdf / reporting: summary plots (i.e., only after each task)
    eval_cb_full = cb._eval_cb(
        log=args.iters,
        test_datasets=test_datasets,
        precision_dict=precision_dict,
        iters_per_task=args.iters,
        classes_per_task=classes_per_task,
        scenario=args.scenario,
    )
    # -visualize feature space
    latent_space_cb = cb._latent_space_cb(
        log=args.iters,
        datasets=test_datasets,
        visdom=visdom,
        iters_per_task=args.iters,
        sample_size=400,
    )
    # -collect them in <lists>
    eval_cbs = [eval_cb, eval_cb_full, latent_space_cb]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    if args.train:
        if verbose:
            print("\nTraining...")
        # Train model
        train_cl(
            model,
            train_datasets,
            replay_mode=args.replay if hasattr(args, 'replay') else "none",
            scenario=args.scenario,
            classes_per_task=classes_per_task,
            iters=args.iters,
            batch_size=args.batch,
            batch_size_replay=args.batch_replay if hasattr(
                args, 'batch_replay') else None,
            generator=generator,
            gen_iters=g_iters,
            gen_loss_cbs=generator_loss_cbs,
            feedback=utils.checkattr(args, 'feedback'),
            sample_cbs=sample_cbs,
            eval_cbs=eval_cbs,
            loss_cbs=generator_loss_cbs
            if utils.checkattr(args, 'feedback') else solver_loss_cbs,
            args=args,
            reinit=utils.checkattr(args, 'reinit'),
            only_last=utils.checkattr(args, 'only_last'))
        # Save evaluation metrics measured throughout training
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        utils.save_object(precision_dict, file_name)
        # Save trained model(s), if requested
        if args.save:
            save_name = "mM-{}".format(param_stamp) if (
                not hasattr(args, 'full_stag')
                or args.full_stag == "none") else "{}-{}".format(
                    model.name, args.full_stag)
            utils.save_checkpoint(model,
                                  args.m_dir,
                                  name=save_name,
                                  verbose=verbose)
            if generator is not None:
                save_name = "gM-{}".format(param_stamp) if (
                    not hasattr(args, 'full_stag')
                    or args.full_stag == "none") else "{}-{}".format(
                        generator.name, args.full_stag)
                utils.save_checkpoint(generator,
                                      args.m_dir,
                                      name=save_name,
                                      verbose=verbose)

    else:
        # Load previously trained model(s) (if goal is to only evaluate previously trained model)
        if verbose:
            print("\nLoading parameters of the previously trained models...")
        load_name = "mM-{}".format(param_stamp) if (
            not hasattr(args, 'full_ltag')
            or args.full_ltag == "none") else "{}-{}".format(
                model.name, args.full_ltag)
        utils.load_checkpoint(
            model,
            args.m_dir,
            name=load_name,
            verbose=verbose,
            add_si_buffers=(isinstance(model, ContinualLearner)
                            and utils.checkattr(args, 'si')))
        if generator is not None:
            load_name = "gM-{}".format(param_stamp) if (
                not hasattr(args, 'full_ltag')
                or args.full_ltag == "none") else "{}-{}".format(
                    generator.name, args.full_ltag)
            utils.load_checkpoint(generator,
                                  args.m_dir,
                                  name=load_name,
                                  verbose=verbose)

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of CLASSIFIER ----#
    #-----------------------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [
        evaluate.validate(
            model,
            test_datasets[i],
            verbose=False,
            test_size=None,
            task=i + 1,
            allowed_classes=list(
                range(classes_per_task * i, classes_per_task *
                      (i + 1))) if args.scenario == "task" else None)
        for i in range(args.tasks)
    ]
    average_precs = sum(precs) / args.tasks
    # -print on screen
    if verbose:
        print("\n Precision on test-set:")
        for i in range(args.tasks):
            print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
        print('=> Average precision over all {} tasks: {:.4f}\n'.format(
            args.tasks, average_precs))
    # -write out to text file
    output_file = open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w')
    output_file.write('{}\n'.format(average_precs))
    output_file.close()

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of GENERATOR -----#
    #-----------------------------------#

    if (utils.checkattr(args, 'feedback') or train_gen
        ) and args.experiment == "CIFAR100" and args.scenario == "class":

        # Dataset and model to be used
        test_set = ConcatDataset(test_datasets)
        gen_model = model if utils.checkattr(args, 'feedback') else generator
        gen_model.eval()

        # Evaluate log-likelihood of generative model on combined test-set (with S=100 importance samples per datapoint)
        ll_per_datapoint = gen_model.estimate_loglikelihood(
            test_set, S=100, batch_size=args.batch)
        if verbose:
            print('=> Log-likelihood on test set: {:.4f} +/- {:.4f}\n'.format(
                np.mean(ll_per_datapoint), np.sqrt(np.var(ll_per_datapoint))))
        # -write out to text file
        output_file = open("{}/ll-{}.txt".format(args.r_dir, param_stamp), 'w')
        output_file.write('{}\n'.format(np.mean(ll_per_datapoint)))
        output_file.close()

        # Evaluate reconstruction error (averaged over number of input units)
        re_per_datapoint = gen_model.calculate_recon_error(
            test_set, batch_size=args.batch, average=True)
        if verbose:
            print(
                '=> Reconstruction error (per input unit) on test set: {:.4f} +/- {:.4f}\n'
                .format(np.mean(re_per_datapoint),
                        np.sqrt(np.var(re_per_datapoint))))
        # -write out to text file
        output_file = open("{}/re-{}.txt".format(args.r_dir, param_stamp), 'w')
        output_file.write('{}\n'.format(np.mean(re_per_datapoint)))
        output_file.close()

        # Try loading the classifier (our substitute for InceptionNet) for calculating IS, FID and Recall & Precision
        # -define model
        config['classes'] = 100
        pretrained_classifier = define.define_classifier(args=args,
                                                         config=config,
                                                         device=device)
        pretrained_classifier.hidden = False
        # -load pretrained weights
        eval_tag = "" if args.eval_tag == "none" else "-{}".format(
            args.eval_tag)
        try:
            utils.load_checkpoint(pretrained_classifier,
                                  args.m_dir,
                                  verbose=True,
                                  name="{}{}".format(
                                      pretrained_classifier.name, eval_tag))
            FileFound = True
        except FileNotFoundError:
            if verbose:
                print("= Could not find model {}{} in {}".format(
                    pretrained_classifier.name, eval_tag, args.m_dir))
                print("= IS, FID and Precision & Recall not computed!")
            FileFound = False
        pretrained_classifier.eval()

        # Only continue with computing these measures if the requested classifier network (using --eval-tag) was found
        if FileFound:
            # Preparations
            total_n = len(test_set)
            n_repeats = int(np.ceil(total_n / args.batch))
            # -sample data from generator (for IS, FID and Precision & Recall)
            gen_x = gen_model.sample(size=total_n, only_x=True)
            # -generate predictions for generated data (for IS)
            gen_pred = []
            for i in range(n_repeats):
                x = gen_x[(i *
                           args.batch):int(min(((i + 1) *
                                                args.batch), total_n))]
                with torch.no_grad():
                    gen_pred.append(
                        F.softmax(pretrained_classifier.hidden_to_output(x)
                                  if args.hidden else pretrained_classifier(x),
                                  dim=1).cpu().numpy())
            gen_pred = np.concatenate(gen_pred)
            # -generate embeddings for generated data (for FID and Precision & Recall)
            gen_emb = []
            for i in range(n_repeats):
                with torch.no_grad():
                    gen_emb.append(
                        pretrained_classifier.feature_extractor(
                            gen_x[(i * args.batch
                                   ):int(min(((i + 1) *
                                              args.batch), total_n))],
                            from_hidden=args.hidden).cpu().numpy())
            gen_emb = np.concatenate(gen_emb)
            # -generate embeddings for test data (for FID and Precision & Recall)
            data_loader = utils.get_data_loader(test_set,
                                                batch_size=args.batch,
                                                cuda=cuda)
            real_emb = []
            for real_x, _ in data_loader:
                with torch.no_grad():
                    real_emb.append(
                        pretrained_classifier.feature_extractor(
                            real_x.to(device)).cpu().numpy())
            real_emb = np.concatenate(real_emb)

            # Calculate "Inception Score" (IS)
            py = gen_pred.mean(axis=0)
            is_per_datapoint = []
            for i in range(len(gen_pred)):
                pyx = gen_pred[i, :]
                is_per_datapoint.append(entropy(pyx, py))
            IS = np.exp(np.mean(is_per_datapoint))
            if verbose:
                print('=> Inception Score = {:.4f}\n'.format(IS))
            # -write out to text file
            output_file = open(
                "{}/is{}-{}.txt".format(args.r_dir, eval_tag, param_stamp),
                'w')
            output_file.write('{}\n'.format(IS))
            output_file.close()

            # Calculate "Frechet Inception Distance" (FID)
            FID = fid.calculate_fid_from_embedding(gen_emb, real_emb)
            if verbose:
                print('=> Frechet Inception Distance = {:.4f}\n'.format(FID))
            # -write out to text file
            output_file = open(
                "{}/fid{}-{}.txt".format(args.r_dir, eval_tag, param_stamp),
                'w')
            output_file.write('{}\n'.format(FID))
            output_file.close()

            # Calculate "Precision & Recall"-curves
            precision, recall = pr.compute_prd_from_embedding(
                gen_emb, real_emb)
            # -write out to text files
            file_name = "{}/precision{}-{}.txt".format(args.r_dir, eval_tag,
                                                       param_stamp)
            with open(file_name, 'w') as f:
                for item in precision:
                    f.write("%s\n" % item)
            file_name = "{}/recall{}-{}.txt".format(args.r_dir, eval_tag,
                                                    param_stamp)
            with open(file_name, 'w') as f:
                for item in recall:
                    f.write("%s\n" % item)

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = evaluate.visual.plt.open_pdf(plot_name)

        # -show metrics reflecting progression during training
        if args.train and (not utils.checkattr(args, 'only_last')):
            # -create list to store all figures to be plotted.
            figure_list = []
            # -generate figures (and store them in [figure_list])
            figure = evaluate.visual.plt.plot_lines(
                precision_dict["all_tasks"],
                x_axes=precision_dict["x_task"],
                line_names=[
                    '{} {}'.format(
                        "episode" if args.scenario == "class" else "task",
                        i + 1) for i in range(args.tasks)
                ],
                xlabel="# of {}s".format("episode" if args.scenario ==
                                         "class" else "task"),
                ylabel="Test accuracy")
            figure_list.append(figure)
            figure = evaluate.visual.plt.plot_lines(
                [precision_dict["average"]],
                x_axes=precision_dict["x_task"],
                line_names=[
                    'Average based on all {}s so far'.format((
                        "digit" if args.experiment == "splitMNIST" else
                        "classe") if args.scenario else "task")
                ],
                xlabel="# of {}s".format("episode" if args.scenario ==
                                         "class" else "task"),
                ylabel="Test accuracy")
            figure_list.append(figure)
            # -add figures to pdf
            for figure in figure_list:
                pp.savefig(figure)

        gen_eval = (utils.checkattr(args, 'feedback') or train_gen)
        # -show samples (from main model or separate generator)
        if gen_eval and not no_samples:
            evaluate.show_samples(
                model if utils.checkattr(args, 'feedback') else generator,
                config,
                size=args.sample_n,
                pdf=pp,
                title="Generated samples (by final model)")

        # -plot "Precision & Recall"-curve
        if gen_eval and args.experiment == "CIFAR100" and args.scenario == "class" and FileFound:
            figure = evaluate.visual.plt.plot_pr_curves([[precision]],
                                                        [[recall]])
            pp.savefig(figure)

        # -close pdf
        pp.close()

        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
Example #19
0
for test_seq in seqs:
    train_seqs = seqs.copy()
    train_seqs.remove(test_seq)
    val_seqs = [test_seq]

    train_data = []
    for seq in train_seqs:
        for i in range(len(train_conds)):
            for j in range(i, len(train_conds)):
                cond1 = train_conds[i]
                cond2 = train_conds[j]
                print('Train {}: {} --> {}'.format(seq, cond1, cond2))
                data = vkitti.TorchDataset(opts, seq, cond1, cond2,
                                           opts.random_crop)
                train_data.append(data)
    train_data = ConcatDataset(train_data)

    val_data = []
    for seq in val_seqs:
        for i in range(len(val_conds)):
            for j in range(i, len(val_conds)):
                cond1 = val_conds[i]
                cond2 = val_conds[j]
                print('Val {}: {} --> {}'.format(seq, cond1, cond2))
                data = vkitti.TorchDataset(opts, seq, cond1, cond2, False)
                val_data.append(data)
    val_data = ConcatDataset(val_data)

    test_data = vkitti.TorchDataset(opts, test_seq, test_conds[0],
                                    test_conds[1], False)
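Example #19 above performs leave-one-sequence-out evaluation: every sequence except the held-out one contributes one dataset per condition pair (a condition paired with itself included), and the pieces are merged with ConcatDataset. A compact sketch of the same pattern, with dummy TensorDatasets standing in for vkitti.TorchDataset (sequence and condition names are placeholders):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

seqs = ['seq01', 'seq02', 'seq03']
conds = ['clone', 'fog', 'rain']

def make_dataset(seq, cond1, cond2):
    # Stand-in for vkitti.TorchDataset: 4 random "image pairs" per (seq, cond1, cond2)
    return TensorDataset(torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8))

for test_seq in seqs:
    train_seqs = [s for s in seqs if s != test_seq]
    train_data = ConcatDataset([
        make_dataset(seq, conds[i], conds[j])
        for seq in train_seqs
        for i in range(len(conds))
        for j in range(i, len(conds))
    ])
    print(test_seq, 'held out ->', len(train_data), 'training samples')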
Example #20
0
def train_cl(model, train_datasets, replay_mode="none", scenario="class", classes_per_task=None,
             iters=2000, batch_size=32, collate_fn=None, visualize=True,
             generator=None, gen_iters=0, gen_loss_cbs=list(), loss_cbs=list(), eval_cbs=list(), sample_cbs=list()):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "generative", "exact", "current", "offline" and "none"
    [scenario]          <str>, choice from "task", "domain" and "class"
    [classes_per_task]  <int>, # of classes per task
    [iters]             <int>, # of optimization-steps (i.e., # of batches) per task
    [visualize]         <bool>, whether all losses should be calculated for plotting (even if not used)
    [generator]         None or <nn.Module>, if a separate generative model should be trained (for [gen_iters] per task)
    [*_cbs]             <list> of call-back functions to evaluate training-progress'''


    # Set model in training-mode
    model.train()

    # Use cuda?
    cuda = model._is_on_cuda()
    device = model._device()

    # Initiate possible sources for replay (no replay for 1st task)
    previous_model = previous_scholar = previous_datasets = None
    exact_replay = generative_replay = current_replay = False

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n), p.data.clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):

        # Do not train if non-positive iterations
        if iters <= 0:
            return

        # If offline replay-setting, create large database of all tasks so far
        if replay_mode=="offline" and (not scenario=="task"):
            train_dataset = ConcatDataset(train_datasets[:task])
        # -but if "offline"+"task"-scenario: all tasks so far included in 'exact replay' & no current batch
        if replay_mode=="offline" and scenario == "task":
            exact_replay = True


        ####################################### MAIN MODEL #######################################

        # Prepare <dicts> to store running importance estimates and param-values before update ("Synaptic Intelligence")
        if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Reset state of optimizer for every task (if requested)
        if model.optim_type=="adam_reset":
            model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

        # Initialize # iters left on current data-loader(s)
        iters_left = iters_left_previous = 1
        if scenario=="task":
            up_to_task = task if replay_mode=="offline" else task-1
            iters_left_previous = [1]*up_to_task
            data_loader_previous = [None]*up_to_task

        # Define a tqdm progress bar
        progress = tqdm.tqdm(range(1, iters+1))

        # Loop over all iterations
        for batch_index in progress:

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            iters_left -= 1
            if iters_left==0:
                data_loader = iter(utils.get_data_loader(train_dataset, batch_size, cuda=cuda, collate_fn=collate_fn,
                                                         drop_last=True))
                iters_left = len(data_loader)
            if exact_replay:
                if scenario=="task":
                    up_to_task = task if replay_mode=="offline" else task-1
                    batch_size_replay = int(np.ceil(batch_size/up_to_task)) if (up_to_task>1) else batch_size
                    # -in incremental task learning scenario, need separate replay for each task
                    for task_id in range(up_to_task):
                        iters_left_previous[task_id] -= 1
                        if iters_left_previous[task_id]==0:
                            data_loader_previous[task_id] = iter(utils.get_data_loader(
                                train_datasets[task_id], batch_size_replay, cuda=cuda,
                                collate_fn=collate_fn, drop_last=True
                            ))
                            iters_left_previous[task_id] = len(data_loader_previous[task_id])
                else:
                    iters_left_previous -= 1
                    if iters_left_previous==0:
                        data_loader_previous = iter(utils.get_data_loader(ConcatDataset(previous_datasets),
                                                                          batch_size, cuda=cuda,
                                                                          collate_fn=collate_fn, drop_last=True))
                        iters_left_previous = len(data_loader_previous)


            #####-----CURRENT BATCH-----#####
            if replay_mode=="offline" and scenario=="task":
                x = y = None
            else:
                x, y = next(data_loader)                                    #--> sample training data of current task
                y = y-classes_per_task*(task-1) if scenario=="task" else y  #--> ITL: adjust y-targets to 'active range'
                x, y = x.to(device), y.to(device)                           #--> transfer them to correct device


            #####-----REPLAYED BATCH-----#####
            if not exact_replay and not generative_replay and not current_replay:
                x_ = y_ = scores_ = None   #-> if no replay

            ##-->> Current Replay <<--##
            if current_replay:
                scores_ = None
                if not scenario=="task":
                    # Use same as CURRENT BATCH to replay
                    x_ = x
                    y_ = y if ((model.replay_targets=="hard") or visualize) else None
                    # Get predicted "logits"/"scores" on replayed data (from previous model)
                    if (model.replay_targets=="soft") or visualize:
                        with torch.no_grad():
                            scores_ = previous_model(x_)
                        scores_ = scores_[:, :(classes_per_task * (task - 1))] if scenario=="class" else scores_
                        # --> ICL: zero probabilities will be added in the [utils.loss_fn_kd]-function
                else:
                    if model.replay_targets=="hard":
                        raise NotImplementedError(
                            "'Current' replay with 'hard targets' not implemented for 'incremental task learning'."
                        )
                    # For each task to replay, use same [x] as in CURRENT BATCH
                    x_ = list()
                    for task_id in range(task-1):
                        x_.append(x)
                    # Get predicted "logits" on replayed data (from previous model)
                    if (model.replay_targets=="soft") or visualize:
                        scores_ = list()
                        for task_id in range(task-1):
                            with torch.no_grad():
                                scores_temp = previous_model(x_[task_id])
                            scores_.append(scores_temp[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))])

            ##-->> Exact Replay <<--##
            if exact_replay:
                scores_ = None
                if not scenario=="task":
                    # Sample replayed training data, wrap in (cuda-)Variables
                    x_, y_ = next(data_loader_previous)
                    x_ = x_.to(device)
                    y_ = y_.to(device) if (model.replay_targets=="hard") or visualize else None
                    # Get predicted "logits"/"scores" on replayed data (from previous model)
                    if (model.replay_targets=="soft") or visualize:
                        with torch.no_grad():
                            scores_ = previous_model(x_)
                        scores_ = scores_[:, :(classes_per_task * (task - 1))] if scenario=="class" else scores_
                        # --> ICL: zero probabilities will be added in the [utils.loss_fn_kd]-function
                else:
                    # Sample replayed training data, wrap in (cuda-)Variables and store in lists
                    x_ = list()
                    y_ = list()
                    up_to_task = task if replay_mode=="offline" else task-1
                    for task_id in range(up_to_task):
                        x_temp, y_temp = next(data_loader_previous[task_id])
                        x_.append(x_temp.to(device))
                        # -only keep [y_] if required (as otherwise unnecessary computations will be done)
                        if (model.replay_targets == "hard") or visualize:
                            y_temp = y_temp - (classes_per_task*task_id) #-> adjust y-targets to 'active range'
                            y_.append(y_temp.to(device))
                        else:
                            y_.append(None)
                    # Get predicted "logits" on replayed data (from previous model)
                    if ((model.replay_targets=="soft") or visualize) and (previous_model is not None):
                        scores_ = list()
                        for task_id in range(up_to_task):
                            with torch.no_grad():
                                scores_temp = previous_model(x_[task_id])
                            scores_.append(scores_temp[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))])

            ##-->> Generative Replay <<--##
            if generative_replay:
                if not scenario=="task":
                    # Which classes could be predicted (=[allowed_predictions])?
                    allowed_predictions = None if scenario=="domain" else list(range(classes_per_task*(task-1)))
                    # Sample replayed data, along with their predicted "logits" (both from previous model / scholar)
                    sample_model = previous_model if generator is None else previous_scholar
                    x_, y_, scores_ = sample_model.sample(batch_size, allowed_predictions=allowed_predictions,
                                                          return_scores=True)
                    # -only keep predicted y/scores if required (as otherwise unnecessary computations will be done)
                    y_ = y_ if ((model.replay_targets=="hard") or visualize) else None
                    scores_ = scores_ if ((model.replay_targets=="soft") or visualize) else None
                else:
                    x_ = list()
                    y_ = list()
                    scores_ = list()
                    # For each previous task, list which classes could be predicted
                    allowed_pred_list = [list(range(classes_per_task*i, classes_per_task*(i+1))) for i in range(task)]
                    for prev_task_id in range(1, task):
                        # Sample replayed data, along with their predicted "logits" (both from previous model / scholar)
                        sample_model = previous_model if generator is None else previous_scholar
                        batch_size_replay = int(np.ceil(batch_size / (task-1))) if (task > 2) else batch_size
                        x_temp, y_temp, scores_temp = sample_model.sample(
                            batch_size_replay, allowed_predictions=allowed_pred_list[prev_task_id-1],
                            return_scores=True,
                        )
                        x_.append(x_temp)
                        # -only keep [y_] / [scores_] if required (as otherwise unnecessary computations will be done)
                        y_.append(y_temp if (model.replay_targets=="hard" or visualize) else None)
                        scores_.append(scores_temp if (model.replay_targets=="soft" or visualize) else None)


            # Find [active_classes]
            active_classes = None  #-> for "domain"-sce, always all classes are active
            if scenario=="task":
                # -for "task"-sce, create <list> with for all tasks so far a <list> with the active classes
                active_classes = [list(range(classes_per_task*i, classes_per_task*(i+1))) for i in range(task)]
            elif scenario=="class":
                # -for "class"-sce, create one <list> with active classes of all tasks so far
                active_classes = list(range(classes_per_task*task))

            # Train the model with this batch
            loss_dict = model.train_a_batch(
                x, y, x_=x_, y_=y_, scores_=scores_, active_classes=active_classes, task=task, rnt = 1./task,
            )

            # Update running parameter importance estimates in W
            if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
                for n, p in model.named_parameters():
                    if p.requires_grad:
                        n = n.replace('.', '__')
                        if p.grad is not None:
                            W[n].add_(-p.grad*(p.detach()-p_old[n]))
                        p_old[n] = p.detach().clone()

            # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
            for loss_cb in loss_cbs:
                if loss_cb is not None:
                    loss_cb(progress, batch_index, loss_dict, task=task)
            for eval_cb in eval_cbs:
                if eval_cb is not None:
                    eval_cb(model, batch_index, task=task)
            if model.label=="VAE":
                for sample_cb in sample_cbs:
                    if sample_cb is not None:
                        sample_cb(model, batch_index, task=task)


        ####################################### GENERATOR #######################################

        if generator is not None:

            # Reset state of optimizer for every task (if requested)
            if generator.optim_type=="adam_reset":
                generator.optimizer = optim.Adam(generator.optim_list, betas=(0.9, 0.999))

            # Initialize number of iters left on current data-loader(s)
            iters_left = iters_left_previous = 1

            # Define a tqdm progress bar
            progress = tqdm.tqdm(range(1, gen_iters+1))

            # Loop over all iterations.
            for batch_index in progress:

                # Update # iters left on current data-loader(s) and, if needed, create new one(s)
                iters_left -= 1
                if iters_left == 0:
                    data_loader = iter(utils.get_data_loader(train_dataset, batch_size, cuda=cuda,
                                                             collate_fn=collate_fn, drop_last=True))
                    iters_left = len(data_loader)
                if exact_replay:
                    iters_left_previous -= 1
                    if iters_left_previous == 0:
                        data_loader_previous = iter(utils.get_data_loader(ConcatDataset(previous_datasets),
                                                                          batch_size, cuda=cuda,
                                                                          collate_fn=collate_fn, drop_last=True))
                        iters_left_previous = len(data_loader_previous)

                # Sample training data of current task
                x, _ = next(data_loader)
                x = x.to(device)

                # Sample replayed training data
                if exact_replay:
                    x_, _ = next(data_loader_previous)
                    x_ = x_.to(device)
                elif generative_replay:
                    x_, _ = previous_scholar.sample(batch_size)
                elif current_replay:
                    x_ = x
                else:
                    x_ = None

                # Train the generator with this batch
                loss_dict = generator.train_a_batch(x, y=None, x_=x_, y_=None, rnt=1./task)

                # Fire callbacks on each iteration
                for loss_cb in gen_loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress, batch_index, loss_dict, task=task)
                for sample_cb in sample_cbs:
                    if sample_cb is not None:
                        sample_cb(generator, batch_index, task=task)


        ##----------> UPON FINISHING EACH TASK...

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and (model.ewc_lambda>0 or visualize):
            allowed_classes = list(
                range(classes_per_task*(task-1), classes_per_task*task)
            ) if scenario=="task" else (list(range(classes_per_task*task)) if scenario=="class" else None)
            model.estimate_fisher(train_dataset, allowed_classes=allowed_classes, collate_fn=collate_fn)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and (model.si_c>0 or visualize):
            model.update_omega(W, model.epsilon)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model)
        previous_model.eval()
        if generator is not None:
            scholar = dgr.Scholar(generator=generator, solver=model)
            previous_scholar = copy.deepcopy(scholar)
        if replay_mode=='generative':
            generative_replay = True
        elif replay_mode=='exact':
            previous_datasets = train_datasets[:task]
            exact_replay = True
        elif replay_mode=='current':
            current_replay = True
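In train_cl above, when replay_targets is "soft" the replayed inputs are trained against the previous model's logits instead of hard labels; the actual loss is computed inside utils.loss_fn_kd (referenced in the comments but not shown here). A rough sketch of a typical temperature-scaled distillation loss of that kind; the function name and signature below are assumptions for illustration, not the repository's API:

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    # KL between temperature-softened teacher and student distributions,
    # scaled by T**2 so gradients keep a magnitude comparable to the CE loss.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T ** 2)

# Example call on random logits:
student = torch.randn(8, 10, requires_grad=True)
teacher = torch.randn(8, 10)
distillation_loss(student, teacher, T=2.0).backward()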
Example #21
0
def meta_train(args, config, train_set, dev_set, label_map, bert_model,
               clf_head):
    save_dir = "./models/{}".format(utils.get_savedir_name())
    tb_writer = SummaryWriter(os.path.join(save_dir, "logs"))

    split_fraction = 1.0 * config.inner_loop_steps / (config.inner_loop_steps +
                                                      1)
    train_set_1, train_set_2 = [], []
    for dataset in train_set:
        ts1 = int(split_fraction * len(dataset))
        ts2 = len(dataset) - ts1
        td1, td2 = torch.utils.data.random_split(
            dataset, [ts1, ts2],
            generator=torch.Generator().manual_seed(config.seed))
        train_set_1.append(td1)
        train_set_2.append(td2)

    train_taskset = data_utils.CustomLangTaskDataset(
        train_set_1, train_type=config.train_type)
    dev_taskset = data_utils.CustomLangTaskDataset(train_set_2)

    eval_set = ConcatDataset(dev_set)
    eval_loader = DataLoader(
        dataset=eval_set,
        batch_size=config.task_batch_size,
        collate_fn=utils.collate_fn,
        shuffle=False,
        num_workers=0,
    )

    num_epochs = config.num_epochs
    task_bs = config.task_batch_size
    inner_loop_steps = config.inner_loop_steps
    num_episodes = len(ConcatDataset(train_set_2)) // task_bs

    meta_clf = meta_utils.ParamMetaSGD(clf_head,
                                       lr=config.inner_lr,
                                       first_order=config.is_fomaml)
    if not config.finetune_enc:
        for param in bert_model.parameters():
            param.requires_grad = False
        extra = []
        meta_encoder = bert_model
    else:
        meta_encoder = meta_utils.ParamMetaSGD(bert_model,
                                               lr=config.inner_lr,
                                               first_order=config.is_fomaml)
        extra = [p for p in meta_encoder.parameters()]

    opt_params = list(meta_clf.parameters()) + extra
    if config.train_type == "metabase":
        opt = Adam(opt_params, lr=config.outer_lr)
    else:
        if config.optim == "adam":
            opt = GDA(
                max_params=train_taskset.parameters(),
                min_params=opt_params,
                lr_max=config.outer_lr,
                lr_min=config.outer_lr,
                device=DEVICE,
            )
        elif config.optim == "alcgd":
            torch.backends.cudnn.benchmark = True
            opt = ALCGD(
                max_params=train_taskset.parameters(),
                min_params=opt_params,
                lr_max=config.outer_lr,
                lr_min=config.outer_lr,
                device=DEVICE,
            )
        else:
            raise ValueError(
                f"Invalid option: {config.optim} for `config.optim`")

    best_dev_error = np.inf
    if args.load_from:
        state_obj = torch.load(os.path.join(args.load_from, "optim.th"))
        opt.load_state_dict(state_obj["optimizer"])
        num_epochs = num_epochs - state_obj["last_epoch"]
        (dev_task, _), _ = dev_taskset.sample(k=config.shots)
        dev_loader = DataLoader(
            data_utils.InnerDataset(dev_task),
            batch_size=task_bs,
            shuffle=False,
            num_workers=0,
        )
        dev_error, dev_metrics = utils.compute_loss_metrics(
            dev_loader,
            bert_model,
            clf_head,
            label_map,
            grad_required=False,
            return_metrics=False,
        )
        best_dev_error = dev_error.mean()

    def save_dist(name):
        save_dir = "./models/{}".format(utils.get_savedir_name())
        with open(os.path.join(save_dir, name), "wb") as f:
            np.save(f, train_taskset.tau.detach().cpu().numpy())

    patience_ctr = 0
    eval_freq = config.eval_freq // (config.inner_loop_steps + 1)
    patience_over = False
    constrain_loss_list = defaultdict(lambda: deque(maxlen=10))
    tqdm_bar = tqdm(range(num_epochs))
    for iteration in tqdm_bar:
        dev_iteration_error = 0.0
        train_iteration_error = 0.0
        meta_encoder.train()
        meta_clf.train()
        episode_iterator = tqdm(range(num_episodes), desc="Training")
        for episode_num in episode_iterator:
            learner = meta_clf.clone()
            encoder = meta_encoder.clone(
            ) if config.finetune_enc else meta_encoder
            (train_task,
             train_langs), imps = train_taskset.sample(k=config.shots)
            (dev_task, _), _ = dev_taskset.sample(k=config.shots,
                                                  langs=train_langs)

            for _ in range(inner_loop_steps):
                train_loader = DataLoader(
                    data_utils.InnerDataset(train_task),
                    batch_size=task_bs,
                    shuffle=True,
                    num_workers=0,
                )
                train_error, train_metrics = utils.compute_loss_metrics(
                    train_loader,
                    encoder,
                    learner,
                    label_map=label_map,
                    return_metrics=False,
                    enc_grad_required=config.finetune_enc,
                )
                train_error = train_error.mean()
                train_iteration_error += train_error.item()
                learner.adapt(train_error, retain_graph=config.finetune_enc)
                if config.finetune_enc:
                    encoder.adapt(train_error, allow_unused=True)

            dev_loader = DataLoader(
                data_utils.InnerDataset(dev_task),
                batch_size=task_bs,
                shuffle=True,
                num_workers=0,
            )
            dev_error, dev_metrics = utils.compute_loss_metrics(
                dev_loader,
                encoder,
                learner,
                label_map,
                return_metrics=False,
                enc_grad_required=config.finetune_enc,
            )
            if config.train_type == "minmax":
                dev_error *= imps
                dev_error = dev_error.sum()
            elif config.train_type == "constrain":
                constrain_val = config.constrain_val
                if (hasattr(config, "constrain_type")
                        and config.constrain_type == "dynamic"):
                    constrain_val = torch.tensor([
                        np.mean(constrain_loss_list[lang])
                        if len(constrain_loss_list[lang]) > 5 else
                        -config.constrain_val for lang in train_langs
                    ]).to(dev_error.device)
                    for loss_val, lang in zip(dev_error, train_langs):
                        constrain_loss_list[lang].append(loss_val.item())
                dev_error = (dev_error.mean() +
                             ((dev_error - constrain_val) * imps).sum())
            elif config.train_type == "metabase":
                dev_error = dev_error.mean()
            else:
                raise ValueError(
                    f"Invalid option: {config.train_type} for `config.train_type`"
                )

            if config.train_type == "metabase":
                dev_error.backward()
                opt.step()
            else:
                opt.step(loss=dev_error)
            opt.zero_grad()

            dev_iteration_error += dev_error.item()

            tb_writer.add_scalar("metrics/loss", dev_error,
                                 (iteration * num_epochs) + episode_num)
            if dev_metrics is not None:
                tb_writer.add_scalar(
                    "metrics/precision",
                    dev_metrics["precision"],
                    (iteration * num_epochs) + episode_num,
                )
                tb_writer.add_scalar(
                    "metrics/recall",
                    dev_metrics["recall"],
                    (iteration * num_epochs) + episode_num,
                )
                tb_writer.add_scalar(
                    "metrics/f1",
                    dev_metrics["f1"],
                    (iteration * num_epochs) + episode_num,
                )

            if episode_num and episode_num % eval_freq == 0:
                dev_iteration_error /= eval_freq
                train_iteration_error /= eval_freq * inner_loop_steps
                if dev_metrics is not None:
                    tqdm_bar.set_description(
                        "Train. Loss: {:.3f} Train F1: {:.3f} Dev. Loss: {:.3f} Dev. F1: {:.3f}"
                        .format(
                            train_iteration_error,
                            train_metrics["f1"],
                            dev_iteration_error,
                            dev_metrics["f1"],
                        ))
                else:
                    tqdm_bar.set_description(
                        "Train. Loss: {:.3f} Dev. Loss: {:.3f}".format(
                            train_iteration_error, dev_iteration_error))

                meta_clf.eval()
                meta_encoder.eval()
                eval_loss, _ = utils.compute_loss_metrics(
                    eval_loader,
                    meta_encoder,
                    meta_clf,
                    label_map,
                    grad_required=False,
                    return_metrics=False,
                )
                eval_error = eval_loss.mean()

                if eval_error < best_dev_error:
                    logger.info("Found new best model!")
                    best_dev_error = eval_error
                    save(meta_clf, opt, args.config_path, iteration,
                         meta_encoder if config.finetune_enc else None)
                    save_dist("best_dist.npy")
                    patience_ctr = 0
                else:
                    patience_ctr += 1
                    if patience_ctr == config.patience:
                        logger.info(
                            "Ran out of patience. Stopping training early...")
                        patience_over = True
                        break
                dev_iteration_error = 0.0
                train_iteration_error = 0.0

        if config.train_type != "metabase" and iteration % 10 == 0:
            save_dist("dist.npy")
        if patience_over:
            break

    logger.info(f"Best validation loss = {best_dev_error}")
    logger.info("Best model saved at: {}".format(utils.get_savedir_name()))
Example #22
0
    def train(self, net, samples, optimizer, e):
        alpha = 2 * max(0, ((50 - e) / 50))
        criterion = losses.ELULovaszFocalWithLogitsLoss(alpha, 2 - alpha)

        transforms = generator.TransformationsGenerator([
            random.RandomFlipLr(),
            random.RandomAffine(image_size=101,
                                translation=lambda rs:
                                (rs.randint(-20, 20), rs.randint(-20, 20)),
                                scale=lambda rs: (rs.uniform(0.85, 1.15), 1),
                                **utils.transformations_options)
        ])

        samples_aux = list(
            set(samples).intersection(set(utils.get_aux_samples())))
        dataset_aux = datasets.ImageDataset(samples_aux, settings.train,
                                            transforms)

        dataset_pseudo = datasets.SemiSupervisedImageDataset(
            samples_test,
            settings.test,
            transforms,
            size=len(samples_test),
            test_predictions=self.test_predictions,
            momentum=0.0)

        dataset = datasets.ImageDataset(samples, settings.train, transforms)
        weight_train = len(dataset_pseudo) / len(dataset) * 2
        weight_aux = weight_train / 2
        weights = [weight_train] * len(dataset) + [weight_aux] * len(
            dataset_aux) + [1] * len(dataset_pseudo)
        dataloader = DataLoader(
            ConcatDataset([dataset, dataset_aux, dataset_pseudo]),
            num_workers=10,
            batch_size=16,
            sampler=WeightedRandomSampler(weights=weights, num_samples=3200))

        average_meter_train = meters.AverageMeter()

        with tqdm(total=len(dataloader), leave=False,
                  ascii=True) as pbar, torch.enable_grad():
            net.train()

            padding = tta.Pad((13, 14, 13, 14))

            for images, masks_targets in dataloader:
                masks_targets = masks_targets.to(gpu)
                masks_predictions = padding.transform_backward(
                    net(padding.transform_forward(images))).contiguous()

                loss = criterion(masks_predictions, masks_targets)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                average_meter_train.add('loss', loss.item())
                self.update_pbar(torch.sigmoid(masks_predictions),
                                 masks_targets, pbar, average_meter_train,
                                 'Training epoch {}'.format(e))

        train_stats = {
            'train_' + k: v
            for k, v in average_meter_train.get_all().items()
        }
        return train_stats
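The train method above mixes three datasets of very different sizes by giving every sample within a dataset the same weight and passing the flattened weight list to WeightedRandomSampler, so each epoch draws a fixed number of samples in the desired proportions. A minimal runnable sketch of the pattern, with dummy TensorDatasets standing in for the image datasets:

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

ds_a = TensorDataset(torch.zeros(100), torch.zeros(100))   # e.g. labelled training images
ds_b = TensorDataset(torch.ones(1000), torch.ones(1000))   # e.g. pseudo-labelled test images

# One weight per *sample*, constant within each sub-dataset
weights = [10.0] * len(ds_a) + [1.0] * len(ds_b)
sampler = WeightedRandomSampler(weights=weights, num_samples=200, replacement=True)

loader = DataLoader(ConcatDataset([ds_a, ds_b]), batch_size=20, sampler=sampler)
batch_x, _ = next(iter(loader))
print('fraction drawn from ds_a:', (batch_x == 0).float().mean().item())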
Example #23
0
    csv_file = filenames[0].split('/')[-1]

    # Create a dataset for one particular csv file
    # essai=DoodlesDataset(csv_file, path,nrows=select_nrows, size=size_image,skiprows=range(1,10))

    # loader=DataLoader(essai,batch_size=10)
    # for image, label in loader:
    #     print(image)
    #     t1=image[0,0,:,:]
    #     #imshow(t1)
    #     print(label)

    doodles = ConcatDataset([
        DoodlesDataset(fn.split('/')[-1],
                       path,
                       nrows=select_nrows,
                       size=size_image) for fn in filenames
    ])

    loader = DataLoader(doodles, batch_size=2, shuffle=True)

    i = 0
    for image, label in loader:
        # print(image)
        t1 = image[0, 0, :, :]
        t2 = image[1, 0, :, :]
        # imshow(t1)
        # imshow(t2)
        i += 2
        print(i)
        print(label)
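Just above, one DoodlesDataset is created per CSV file and the pieces are merged with ConcatDataset so a single DataLoader can shuffle across all classes at once. A self-contained toy version of the same idea; ToyCsvDataset is a made-up stand-in, not the real DoodlesDataset:

import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset

class ToyCsvDataset(Dataset):
    # Stand-in for a per-CSV dataset such as DoodlesDataset
    def __init__(self, label, n_rows):
        self.label = label
        self.n_rows = n_rows

    def __len__(self):
        return self.n_rows

    def __getitem__(self, idx):
        return torch.randn(1, 28, 28), self.label   # (image, class label)

doodles = ConcatDataset([ToyCsvDataset(label=i, n_rows=50) for i in range(4)])
loader = DataLoader(doodles, batch_size=8, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels)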
Example #24
0
def main(config):
    opts = config()
    path = opts.path
    train = pd.read_csv(f'{path}/train.csv')
    pseudo_label = pd.read_csv(
        './submissions/submission_segmentation_and_classifier.csv')

    n_train = len(os.listdir(f'{path}/train_images'))
    n_test = len(os.listdir(f'{path}/test_images'))
    print(f'There are {n_train} images in train dataset')
    print(f'There are {n_test} images in test dataset')

    train.loc[train['EncodedPixels'].isnull() == False,
              'Image_Label'].apply(lambda x: x.split('_')[1]).value_counts()
    train.loc[train['EncodedPixels'].isnull() == False, 'Image_Label'].apply(
        lambda x: x.split('_')[0]).value_counts().value_counts()

    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[1])
    train['im_id'] = train['Image_Label'].apply(lambda x: x.split('_')[0])
    id_mask_count = train.loc[train['EncodedPixels'].isnull() == False,
                              'Image_Label'].apply(lambda x: x.split('_')[
                                  0]).value_counts().reset_index().rename(
                                      columns={
                                          'index': 'img_id',
                                          'Image_Label': 'count'
                                      })
    print(id_mask_count.head())

    pseudo_label.loc[pseudo_label['EncodedPixels'].isnull() == False,
                     'Image_Label'].apply(
                         lambda x: x.split('_')[1]).value_counts()
    pseudo_label.loc[pseudo_label['EncodedPixels'].isnull() == False,
                     'Image_Label'].apply(lambda x: x.split('_')[0]
                                          ).value_counts().value_counts()

    pseudo_label['label'] = pseudo_label['Image_Label'].apply(
        lambda x: x.split('_')[1])
    pseudo_label['im_id'] = pseudo_label['Image_Label'].apply(
        lambda x: x.split('_')[0])
    pseudo_label_ids = pseudo_label.loc[
        pseudo_label['EncodedPixels'].isnull() == False, 'Image_Label'].apply(
            lambda x: x.split('_')[0]).value_counts().reset_index().rename(
                columns={
                    'index': 'img_id',
                    'Image_Label': 'count'
                })
    print(pseudo_label_ids.head())

    if not os.path.exists("csvs/train_all.csv"):
        train_ids, valid_ids = train_test_split(
            id_mask_count,
            random_state=39,
            stratify=id_mask_count['count'],
            test_size=0.1)
        valid_ids.to_csv("csvs/valid_threshold.csv", index=False)
        train_ids.to_csv("csvs/train_all.csv", index=False)
    else:
        train_ids = pd.read_csv("csvs/train_all.csv")
        valid_ids = pd.read_csv("csvs/valid_threshold.csv")

    for fold, ((train_ids_new, valid_ids_new),
               (train_ids_pl, valid_ids_pl)) in enumerate(
                   zip(
                       stratified_groups_kfold(train_ids,
                                               target='count',
                                               n_splits=opts.fold_max,
                                               random_state=0),
                       stratified_groups_kfold(pseudo_label_ids,
                                               target='count',
                                               n_splits=opts.fold_max,
                                               random_state=0))):

        train_ids_new.to_csv(f'csvs/train_fold{fold}.csv')
        valid_ids_new.to_csv(f'csvs/valid_fold{fold}.csv')
        train_ids_new = train_ids_new['img_id'].values
        valid_ids_new = valid_ids_new['img_id'].values

        train_ids_pl = train_ids_pl['img_id'].values
        valid_ids_pl = valid_ids_pl['img_id'].values

        ENCODER = opts.backborn
        ENCODER_WEIGHTS = opts.encoder_weights
        DEVICE = 'cuda'

        ACTIVATION = None
        model = get_model(
            model_type=opts.model_type,
            encoder=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            activation=ACTIVATION,
            n_classes=opts.class_num,
            task=opts.task,
            center=opts.center,
            attention_type=opts.attention_type,
            head='simple',
            classification=opts.classification,
        )
        model = convert_model(model)
        preprocessing_fn = encoders.get_preprocessing_fn(
            ENCODER, ENCODER_WEIGHTS)

        num_workers = opts.num_workers
        bs = opts.batchsize

        train_dataset = CloudDataset(
            df=train,
            label_smoothing_eps=opts.label_smoothing_eps,
            datatype='train',
            img_ids=train_ids_new,
            transforms=get_training_augmentation(opts.img_size),
            preprocessing=get_preprocessing(preprocessing_fn))
        valid_dataset = CloudDataset(
            df=train,
            datatype='valid',
            img_ids=valid_ids_new,
            transforms=get_validation_augmentation(opts.img_size),
            preprocessing=get_preprocessing(preprocessing_fn))

        ################# make pseudo label dataset #######################
        train_dataset_pl = CloudPseudoLabelDataset(
            df=pseudo_label,
            datatype='train',
            img_ids=train_ids_pl,
            transforms=get_training_augmentation(opts.img_size),
            preprocessing=get_preprocessing(preprocessing_fn))
        valid_dataset_pl = CloudPseudoLabelDataset(
            df=pseudo_label,
            datatype='train',
            img_ids=valid_ids_pl,
            transforms=get_validation_augmentation(opts.img_size),
            preprocessing=get_preprocessing(preprocessing_fn))

        #         train_dataset = ConcatDataset([train_dataset, train_dataset_pl])
        #         valid_dataset = ConcatDataset([valid_dataset, valid_dataset_pl])
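        # only the pseudo-labelled fold's validation split is merged into the
        # training data; the real validation set stays untouched for scoring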
        train_dataset = ConcatDataset([train_dataset, valid_dataset_pl])
        ################# make pseudo label dataset #######################
        train_loader = DataLoader(train_dataset,
                                  batch_size=bs,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=bs,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  drop_last=True)

        loaders = {"train": train_loader, "valid": valid_loader}
        num_epochs = opts.max_epoch
        logdir = f"{opts.logdir}/fold{fold}"
        optimizer = get_optimizer(optimizer=opts.optimizer,
                                  lookahead=opts.lookahead,
                                  model=model,
                                  separate_decoder=True,
                                  lr=opts.lr,
                                  lr_e=opts.lr_e)
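        # NVIDIA apex amp: opt_level 'O1' enables mixed-precision training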
        opt_level = 'O1'
        model.cuda()
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=opt_level)
        scheduler = opts.scheduler(optimizer)
        criterion = opts.criterion
        runner = SupervisedRunner()
        if opts.task == "segmentation":
            callbacks = [DiceCallback()]
        else:
            callbacks = []
        if opts.early_stop:
            callbacks.append(
                EarlyStoppingCallback(patience=10, min_delta=0.001))
        if opts.mixup:
            callbacks.append(MixupCallback(alpha=0.25))
        if opts.accumeration is not None:
            callbacks.append(CriterionCallback())
            callbacks.append(
                OptimizerCallback(accumulation_steps=opts.accumeration))
        print(
            f"############################## Start training of fold{fold}! ##############################"
        )
        runner.train(model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     loaders=loaders,
                     callbacks=callbacks,
                     logdir=logdir,
                     num_epochs=num_epochs,
                     verbose=True)
        print(
            f"############################## Finish training of fold{fold}! ##############################"
        )
        del model
        del loaders
        del runner
        torch.cuda.empty_cache()
        gc.collect()
Example #25
0
def optimization_function(input_arguments):
    next_input_arguments_index = 0

    all_datasets = []

    wandb.init(project="auto_augment")
    wandb_name = 'test_sst2_val_5_hyperopt_'

    model_name = 'bert-base-uncased'
    dataset_identifier = 'val_5'

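    # each sub-policy is encoded as 4 consecutive values:
    # (augmentation index, application probability, argument 1, argument 2)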
    for i in range(args.sub_policies_per_policy):
        function_name = functions[int(
            input_arguments[next_input_arguments_index])]
        probability_of_application = input_arguments[next_input_arguments_index
                                                     + 1]
        argument_1 = input_arguments[next_input_arguments_index + 2]
        argument_2 = input_arguments[next_input_arguments_index + 3]
        wandb_name += function_name + "_" + str(
            probability_of_application) + "_" + str(argument_1) + "_" + str(
                argument_2) + "_"
        next_input_arguments_index += 4

        val = pickle.load(open('../data/sst2/sst2_10_samples_val.pkl', 'rb'))

        if function_name in ['random', 'intra_lada']:
            # 0 arguments
            mix = function_name_to_tmix_functions_map[function_name]
            val_dataset = create_dataset(
                val['X'],
                val['y'],
                model_name,
                256,
                mix=mix,
                num_classes=2,
                probability_of_application=probability_of_application,
                dataset_identifier=dataset_identifier)
        elif function_name in [
                'synonym_replacement', 'random_insert', 'random_swap',
                'random_delete'
        ]:
            # one argument
            mix = function_name_to_tmix_functions_map[function_name]
            val_dataset = create_dataset(
                val['X'],
                val['y'],
                model_name,
                256,
                mix=mix,
                num_classes=2,
                probability_of_application=probability_of_application,
                alpha=argument_1,
                dataset_identifier=dataset_identifier)
        elif function_name == 'inter_lada':
            # 2 arguments
            knn = int(argument_1 * 10)
            mu = argument_2
            val_dataset = create_dataset(
                val['X'],
                val['y'],
                model_name,
                256,
                mix='Inter_LADA',
                num_classes=2,
                probability_of_application=probability_of_application,
                knn_lada=knn,
                mu_lada=mu,
                dataset_identifier=dataset_identifier)
        else:
            assert False

        all_datasets.append(val_dataset)

    wandb.run.name = wandb_name
    wandb.run.save()

    val_dataset_combined = ConcatDataset(all_datasets)
    val_dataloader = DataLoader(val_dataset_combined,
                                batch_size=args.batch_size,
                                num_workers=4)

    base_model = torch.load(args.checkpoint_path).cuda()
    base_model.eval()

    with torch.no_grad():
        loss_total = 0
        total_sample = 0

        for batch in tqdm(val_dataloader):
            encoded_1, encoded_2, label_1, label_2 = batch
            assert encoded_1.shape == encoded_2.shape

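            # TMix-style interpolation: pick a random layer to mix at and a
            # Beta-distributed coefficient biased towards the first example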
            mix_layer = np.random.choice(args.mix_layers)
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)

            logits = base_model(encoded_1.cuda(), encoded_2.cuda(), l,
                                mix_layer)
            combined_labels = label_1 * l + label_2 * (1 - l)
            loss = own_loss(logits, combined_labels.cuda(), num_labels=10)
            loss_total += loss.item() * encoded_1.shape[0]
            total_sample += encoded_1.shape[0]

        wandb.log({'Test loss': loss_total})
        return loss_total
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Classifier().to(device)
model.device = device

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)

n_epochs = 100

for epoch in range(n_epochs):

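    # self-training: pseudo-label the unlabeled pool with the current model
    # every epoch and train on the labelled + pseudo-labelled mix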
    if epoch >= 0:
        new_set = get_pseudo_label(train_loader, unlabeled_set, model)
        concat_dataset = ConcatDataset([train_set, new_set])

        train_loader = DataLoader(concat_dataset, batch_size=10)

    model.train()

    train_loss = []
    train_accs = []

    for batch in tqdm(train_loader):
        imgs, labels = batch

        logits = model(imgs.to(device))

        loss = criterion(logits, labels.to(device))
Example #27
0
 def test_concat_two_non_singletons(self):
     result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     self.assertEqual(10, len(result))
     self.assertEqual(0, result[0])
     self.assertEqual(5, result[5])
Example #28
0
    def load_dataset(self, split, combine=False, shuffle=True):
        """Load a dataset split."""

        def split_exists(split, src, tgt, lang):
            filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedInMemoryDataset.exists(path):
                return IndexedInMemoryDataset(path, fix_lua_indexing=True)
            return None

        src_datasets = []
        tgt_datasets = []

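        # with combine=True, keep loading extra shards named "<split>1",
        # "<split>2", ... until one is missing; otherwise stop after the first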
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')

            # infer langcode
            src, tgt = self.args.source_lang, self.args.target_lang
            if split_exists(split_k, src, tgt, src):
                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
            elif split_exists(split_k, tgt, src, src):
                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

            print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))

            if not combine:
                break

        assert len(src_datasets) == len(tgt_datasets)

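        # a single shard is used directly; multiple shards are merged with
        # ConcatDataset and their size arrays concatenated to match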
        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
            src_sizes = src_dataset.sizes
            tgt_sizes = tgt_dataset.sizes
        else:
            src_dataset = ConcatDataset(src_datasets)
            tgt_dataset = ConcatDataset(tgt_datasets)
            src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
            tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])

        self.datasets[split] = LanguagePairDataset(
            src_dataset, src_sizes, self.src_dict,
            tgt_dataset, tgt_sizes, self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            shuffle=shuffle,
        )
Example #29
0
 def test_concat_raises_index_error(self):
     result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     with self.assertRaises(IndexError):
         # this one goes to 11
         result[11]
Example #30
0
    # assumed opening: the train split mirrors the val/test definitions below
    train_dataset = CSGODataset(
        transform=transform_multichannel,
        dataset_split='train',
        verbose=False,
    )

    val_dataset = CSGODataset(
        transform=transform_multichannel,
        dataset_split='val',
        verbose=False,
    )

    test_dataset = CSGODataset(
        transform=transform_multichannel,
        dataset_split='test',
        verbose=False,
    )

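    # merge the train and val splits so the model sees all labelled data;
    # the test split keeps its own loader below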
    train_val_dataset = ConcatDataset([train_dataset, val_dataset])

    # implicit else
    train_loader = torch.utils.data.DataLoader(
        train_val_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=0,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=0,
    )