Example no. 1
def run(config, train_system=True, only_train_replay=False, train_tandem=True):
    """ Method to start training MNIST replay model. 
    Depending on the configurations, here we control the creation and 
    training of the different replay modules with their corresponding 
    hypernetworks.
        
    Args:
        config: The command line arguments.
        train_system: (optional) Set to false if we want this function 
            only to create config, networks and data_handlers for future 
            training. See :func:`mnist.train_splitMNIST.run` for a use case.
        only_train_replay: (optional) If this script will only be used to 
            train a replay model. Normally, we use this script in tandem 
            with an additional classifier that uses this replay model to 
            replay old tasks data.
        train_tandem: (optional) If we will use this script to train in 
            tandem i.e. in an alternating fashion with a classifier.
    Returns:
        (tuple): Tuple containing:
        (....): See docstring of function :func:`train`.
    """

    # If we want to train a classifier on single classes, then we need a
    # single-class replay method. Otherwise this is not required, i.e., we
    # can have a single-class replay method but train our classifier on the
    # replay data (built out of multiple replayed conditions) and the
    # current data at once.
    # Single-class replay is only implemented for splitMNIST.
    if config.single_class_replay:
        assert (config.experiment == "splitMNIST")

    if config.num_tasks > 100 and config.cl_scenario != 1:
        print("Attention: Replay model not tested for num tasks > 100")

    ### Setup environment
    device, writer = train_utils._setup_environment(config)

    ### Number of classes per splitMNIST task (1 for single-class replay,
    ### else 2).
    if config.single_class_replay:
        steps = 1
    else:
        steps = 2

    ### Create tasks for split MNIST
    if not train_system and not config.upper_bound:
        dhandlers = None
    else:
        dhandlers = train_utils._generate_tasks(config, steps)

    ### Generate networks.
    if not train_system:
        enc, dec, d_hnet = None, None, None
    else:
        # The replay hypernetwork is only needed if its regularizer is
        # active (rp_beta > 0).
        create_rp_hnet = config.rp_beta > 0
        enc, dec, d_hnet = train_utils_replay.generate_replay_networks(config,
                                                                       dhandlers, device, create_rp_hnet,
                                                                       only_train_replay=only_train_replay)
        ### Generate task priors for the latent space.
        priors = []
        test_z = []
        vae_conds = []

        ### Save some noise vectors for testing

        for t in range(config.num_embeddings):
            # If using a conditional prior, create a random per-task prior
            # mean and save it.
            if config.conditional_prior:
                mu = torch.zeros((config.latent_dim)).to(device)
                nn.init.normal_(mu, mean=0, std=1.)
                mu = torch.stack([mu] * config.batch_size)
                mu.requires_grad = False
                priors.append(mu)
            else:
                # A standard normal prior is shared by all tasks, so no
                # per-task prior is stored; `mu` (zeros) is only needed for
                # the sampling step below.
                mu = torch.zeros((config.batch_size,
                                  config.latent_dim)).to(device)
                priors.append(None)

            ### Draw a sample from the latent prior.
            eps = torch.randn_like(mu)

            sample = mu + eps
            sample.requires_grad = False
            test_z.append(sample)
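            # Since eps ~ N(0, I), `sample` = mu + eps is a draw from
            # N(mu, I); the draws are stored in `test_z` so testing always
            # reuses the same latent codes.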

            # If the VAE has a conditional input, save either one-hot
            # encodings or conditions drawn from a Gaussian.
            if config.conditional_replay:
                vae_c = torch.zeros((config.conditional_dim)).to(device)
                if not config.not_conditional_hot_enc:
                    vae_c[t] = 1
                else:
                    nn.init.normal_(vae_c, mean=0, std=1.)
                vae_c = torch.stack([vae_c] * config.batch_size)
                vae_c.requires_grad = False
                vae_conds.append(vae_c)

        config.test_z = test_z
        config.priors = priors
        config.vae_conds = vae_conds
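        # The fixed priors, test samples and VAE conditions are stashed on
        # `config` so that `train`/`test` below (and any tandem caller that
        # receives the returned `config`) reuse identical latent codes.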
        if not train_tandem:
            ### Train the network.
            train(dhandlers, enc, dec, d_hnet, device, config, writer)

            ### Test network.
            test(enc, dec, d_hnet, device, config, writer)

    return dec, d_hnet, enc, dhandlers, device, writer, config
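
A minimal usage sketch (hypothetical module names: it assumes this ``run`` lives in a module ``train_replay`` and reuses ``train_args.parse_cmd_arguments`` from Example no. 2):

import train_args
import train_replay

config = train_args.parse_cmd_arguments(mode='split')

# Train only the replay model, without the alternating classifier phase.
dec, d_hnet, enc, dhandlers, device, writer, config = train_replay.run(
    config, only_train_replay=True, train_tandem=False)

# Or skip training here: this mainly sets up device, writer and config for
# a caller such as mnist.train_splitMNIST.run (networks come back as None).
_, _, _, dhandlers, device, writer, config = train_replay.run(
    config, train_system=False)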
Example no. 2
def run(mode='split'):
    """ Method to start MNIST experiments. 
    Depending on the configurations, here we control the creation and 
    training of the different (replay) modules for classification or 
    task inference build out of standart neural networks and their 
    corresponding hypernetworks.

    Args:
        mode (str): Training mode defines which experiments and default values 
        are loaded. Options are splitMNIST or permutedMNIST:

                - ``split``
                - ``perm``
    """

    ### Get command line arguments.
    config = train_args.parse_cmd_arguments(mode=mode)

    assert (config.experiment == "splitMNIST" or \
            config.experiment == "permutedMNIST")
    if not config.dont_set_default:
        config = _set_default(config)

    if config.infer_output_head:
        assert (config.infer_task_id == True)

    if config.cl_scenario == 1:
        assert (config.class_incremental == False)
        assert (config.single_class_replay == False)

    if config.infer_with_entropy:
        assert (config.infer_task_id == True)
    # Single-class replay and class-incremental learning are only
    # implemented for splitMNIST.
    if config.single_class_replay or config.class_incremental:
        assert (config.experiment == "splitMNIST")

    # check range of number of tasks
    assert (config.num_tasks > 0)
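    # splitMNIST draws on MNIST's 10 classes: at most 10 single-class tasks
    # in the class-incremental setting, otherwise at most 5 two-class tasks.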
    if config.experiment == "splitMNIST":
        if config.class_incremental:
            assert (config.num_tasks <= 10)
        else:
            assert (config.num_tasks <= 5)

    # the following combination is not supported 
    if config.infer_task_id:
        assert (config.class_incremental == False)

    # enforce correct cl scenario
    if config.class_incremental:
        config.single_class_replay = 1
        config.cl_scenario = 3
        print("Attention: Cl scenario 3 is enforced!")
        steps = 1
    else:
        steps = 2

    ### Get data handlers.
    dhandlers_class = train_utils._generate_tasks(config, steps)

    # Decide whether to train a replay model. If we only want a classifier
    # and the task id is known, we train just a classifier + hnet. The
    # upper bound considers the replay case, but replays real data as if
    # the replay model were "perfect".
    if config.upper_bound or (config.infer_task_id and config.cl_scenario == 1):
        train_rp = False
    else:
        train_rp = True

    ### Get replay model trained continually with hnet.
    dec, d_hnet, enc, dhandlers_rp, device, writer, config = \
        replay_model(config, train_rp)

    # If a replay model has been trained, we now train a classifier that
    # either solves the task directly (hnet + replay) or a model that
    # infers the task from the input.

    ###############################
    # Train task inference network
    ###############################

    if config.infer_task_id and config.cl_scenario != 1 and \
            not config.infer_with_entropy:
        print("Training task inference model ...")
        config.trained_replay_model = False
        config.training_task_infer = True
        config.training_with_hnet = False
        ### Generate task inference network.
        infer_net = train_utils.generate_classifier(config,
                                                    dhandlers_class, device)

        ### Train the task inference network.
        config.during_accs_inference = train_tasks(dhandlers_class,
                                                   dhandlers_rp, enc, dec, d_hnet, None,
                                                   device, config, writer, infer_net=infer_net)
        ### Test network.
        print("Testing task inference model ...")
        test(dhandlers_class, None, infer_net, device, config, writer)
        config.training_with_hnet = True
        config.trained_replay_model = True
    else:
        # If we do not train an inference network, we just train a model
        # that is given the task identity directly.
        infer_net = None
        if config.infer_with_entropy:
            config.trained_replay_model = True
        else:
            config.trained_replay_model = False

    if config.infer_task_id:
        config.training_with_hnet = True
    else:
        config.training_with_hnet = False

    ###################
    # Train classifier
    ###################

    config.training_task_infer = False

    print("Training final classifier ...")
    ### Generate another classifier network.
    class_nets = train_utils.generate_classifier(config,
                                                 dhandlers_class, device)
    ### Train the network.
    config.during_accs_final = train_tasks(dhandlers_class, dhandlers_rp, enc,
                                           dec, d_hnet, class_nets, device, config, writer, infer_net)

    print("Testing final classifier ...")
    ### Test network.
    test(dhandlers_class, class_nets, infer_net, device, config, writer)

    _save_performance_summary(config)
    writer.close()

    print('Program finished successfully.')
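
A usage sketch for the entry point (the ``__main__`` guard is an assumption; ``mode`` only selects which defaults ``train_args.parse_cmd_arguments`` loads):

if __name__ == '__main__':
    run(mode='split')  # splitMNIST; use run(mode='perm') for permutedMNIST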