Пример #1
0
def main(args):
    """
    Run inference for SS-VAE and log the per-example ELBO and the elapsed time of each epoch.

    :param args: arguments for SS-VAE. The attributes read here are ``seed``, ``z_dim``,
        ``hidden_layers``, ``cuda``, ``enum_discrete``, ``aux_loss_multiplier``,
        ``learning_rate``, ``beta_1``, ``jit``, ``aux_loss``, ``logfile``, ``batch_size``,
        ``sup_num`` and ``num_epochs``.
    :return: None
    """
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    # batch_size: number of images (and labels) to be considered in a batch
    ss_vae = SSVAE(z_dim=args.z_dim,
                   hidden_layers=args.hidden_layers,
                   use_cuda=args.cuda,
                   config_enum=args.enum_discrete,
                   aux_loss_multiplier=args.aux_loss_multiplier)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    # NOTE(review): `model` and `guide` are resolved from the enclosing scope here, not from
    # `ss_vae` (unlike the auxiliary loss below) -- confirm this is intended.
    guide_enum = config_enumerate(guide, args.enum_discrete, expand=True)
    elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_plate_nesting=1)
    loss_basic = SVI(model, guide_enum, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    if args.aux_loss:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo)
        losses.append(loss_aux)

    # Bind the name before entering the try block so that the finally clause
    # never touches an unbound variable when open() itself fails.
    logger = None
    try:
        # setup the logger if a filename is provided
        if args.logfile:
            logger = open(args.logfile, "w")

        data_loaders = setup_data_loaders(MNISTCached, args.cuda, args.batch_size, sup_num=args.sup_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
        # until we have traversed through the all supervised batches
        periodic_interval_batches = int(MNISTCached.train_data_size / (1.0 * args.sup_num))

        # number of unsupervised examples
        unsup_num = MNISTCached.train_data_size - args.sup_num

        # best validation accuracy seen across epochs and the test accuracy
        # of that same epoch; tracked but currently not written to the log
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        print_and_log(logger, args)
        print_and_log(logger, "\nepoch\telbo(sup)\telbo(unsup)\ttime(sec)")
        last_time = time.time()

        # run inference for a certain number of epochs
        for i in range(0, args.num_epochs):

            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # The reported time is measured between consecutive timestamps, so
            # from the second epoch on it also covers the accuracy evaluation
            # of the previous epoch.
            now = time.time()
            elapsed = now - last_time
            last_time = now

            # per-example ELBO = negated average epoch loss
            str_elbo_sup = " ".join(f"{-(v / args.sup_num):.4f}" for v in epoch_losses_sup)
            str_elbo_unsup = " ".join(f"{-(v / unsup_num):.4f}" for v in epoch_losses_unsup)
            str_print = f"{i:06d}\t{str_elbo_sup}\t{str_elbo_unsup}\t{elapsed:.3f}"

            validation_accuracy = get_accuracy(data_loaders["valid"], ss_vae.classifier, args.batch_size)

            # this test accuracy is only for bookkeeping, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size)

            # update the best validation accuracy and the corresponding testing accuracy
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            print_and_log(logger, str_print)

        # The final evaluation is still executed (as before) although its
        # result is not logged at the moment.
        final_test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, args.batch_size)

    finally:
        # close the logger file object if we opened it earlier
        if logger is not None:
            logger.close()
Пример #2
0
def main(args):
    """
    Run inference for CVAE, resuming from ``cvae.model.pt`` when that file exists.

    After the last epoch the whole model object is saved back to the same file.

    :param args: arguments for CVAE. The attributes read here are ``seed``, ``cuda``,
        ``z_dim``, ``hidden_dimension``, ``learning_rate``, ``beta_1``, ``enum_discrete``,
        ``logfile``, ``batch_size`` and ``num_epochs``.
    :return: None
    """
    if args.seed is not None:
        set_seed(args.seed, args.cuda)

    model_path = 'cvae.model.pt'

    if os.path.exists(model_path):
        print('Loading model %s' % model_path)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        cvae = torch.load(model_path)
    else:
        cvae = CVAE(z_dim=args.z_dim,
                    y_dim=8,
                    x_dim=32612,
                    hidden_dim=args.hidden_dimension,
                    use_cuda=args.cuda)

    print(cvae)

    # setup the optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta_1, 0.999),
        "clip_norm": 0.5
    }
    optimizer = ClippedAdam(adam_params)
    guide = config_enumerate(cvae.guide, args.enum_discrete)

    # set up the loss for inference.
    # NOTE(review): newer Pyro releases call this argument `max_plate_nesting`;
    # confirm against the installed Pyro version before renaming.
    loss = SVI(cvae.model,
               guide,
               optimizer,
               loss=TraceEnum_ELBO(max_iarange_nesting=1))

    # Bind the name before entering the try block so that the finally clause
    # never touches an unbound variable when open() itself fails.
    logger = None
    try:
        # setup the logger if a filename is provided
        if args.logfile:
            logger = open(args.logfile, "w")

        data_loaders = setup_data_loaders(NHANES, args.cuda, args.batch_size)
        print(len(data_loaders['train']))
        print(len(data_loaders['test']))
        print(len(data_loaders['valid']))

        # lowest validation error and lowest test error seen across epochs;
        # the two minima are tracked independently of each other
        best_valid_err, best_test_err = float('inf'), float('inf')

        # run inference for a certain number of epochs
        for i in range(0, args.num_epochs):

            # get the losses for an epoch
            epoch_losses = \
                run_inference_for_epoch(args.batch_size, data_loaders, loss, args.cuda)

            # compute average epoch losses i.e. losses per example
            avg_epoch_losses = epoch_losses / NHANES.train_size

            # store the losses in the logfile
            str_print = "{} epoch: avg loss {}".format(i, str(avg_epoch_losses))

            validation_err = get_accuracy(data_loaders["valid"],
                                          cvae.sim_measurements)
            str_print += " validation error {}".format(validation_err)

            # this test error is only for logging, this is not used
            # to make any decisions during training
            test_err = get_accuracy(data_loaders["test"],
                                    cvae.sim_measurements)
            str_print += " test error {}".format(test_err)

            # update the running minima
            if best_valid_err > validation_err:
                best_valid_err = validation_err
            if best_test_err > test_err:
                best_test_err = test_err

            print_and_log(logger, str_print)

        final_test_accuracy = get_accuracy(data_loaders["test"],
                                           cvae.sim_measurements)

        # NOTE(review): the value reported as "corresponding testing error" is the
        # independent minimum test error, not the test error of the best validation epoch.
        print_and_log(
            logger, "best validation error {} corresponding testing error {} "
            "last testing error {}".format(best_valid_err, best_test_err,
                                           final_test_accuracy))
        torch.save(cvae, model_path)

    finally:
        # close the logger file object if we opened it earlier
        if logger is not None:
            logger.close()
Пример #3
0
def main():
    """
    Run semi-supervised SS-VAE inference on the WUST (labelled) and weak (unlabelled) data.

    The labelled train/test sets and the unlabelled encoded set are read from the
    pickle files at ``WUST_TRAIN_PATH``, ``WUST_TEST_PATH`` and ``WEAK_SMALL_ENCODED_PATH``.
    All hyper-parameters come from module-level constants (``BATCH_SIZE``, ``SEED``,
    ``Z_DIM``, ``HIDDEN_LAYERS``, ``CUDA``, ``ENUM_DISCRETE``, ``AUX_LOSS_MULTIPLIER``,
    ``LEARNING_RATE``, ``BETA_1``, ``USE_JIT``, ``AUX_LOSS``, ``LOGFILE``, ``NUM_EPOCHS``).

    :return: None
    """
    # NOTE: pickle.load executes arbitrary code from the file; these paths must
    # only ever point at trusted, locally produced data.
    with WUST_TRAIN_PATH.open('rb') as file:
        wust_train = pickle.load(file)

    with WUST_TEST_PATH.open('rb') as file:
        wust_test = pickle.load(file)

    with WEAK_SMALL_ENCODED_PATH.open('rb') as file:
        weak_small_x = torch.FloatTensor(pickle.load(file))

    # element 1 of each pickled record holds the features, element 2 the labels
    wust_train_x = torch.FloatTensor(wust_train[1])
    wust_test_x = torch.FloatTensor(wust_test[1])

    wust_train_labels = wust_train[2]
    wust_test_labels = wust_test[2]

    # the label encoding is fitted on the training labels only
    label_encoder = LabelEncoder()
    label_encoder.fit(wust_train_labels)

    wust_train_y = label_encoder.transform(wust_train_labels)
    wust_test_y = label_encoder.transform(wust_test_labels)
    wust_train_y = torch.LongTensor(wust_train_y.reshape(-1, 1))
    wust_test_y = torch.LongTensor(wust_test_y.reshape(-1, 1))

    wust_train_ds = TensorDataset(wust_train_x, wust_train_y)
    wust_train_dl = DataLoader(wust_train_ds,
                               batch_size=BATCH_SIZE,
                               shuffle=True)

    wust_test_ds = TensorDataset(wust_test_x, wust_test_y)
    wust_test_dl = DataLoader(wust_test_ds,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    weak_train_ds = TensorDataset(weak_small_x)
    weak_train_dl = DataLoader(weak_train_ds,
                               batch_size=BATCH_SIZE,
                               shuffle=True)

    pyro.set_rng_seed(SEED)

    ss_vae = SSVAE(input_size=wust_train_x.shape[1],
                   output_size=len(label_encoder.classes_),
                   z_dim=Z_DIM,
                   hidden_layers=HIDDEN_LAYERS,
                   use_cuda=CUDA,
                   config_enum=ENUM_DISCRETE,
                   aux_loss_multiplier=AUX_LOSS_MULTIPLIER)

    # setup the optimizer
    optimizer = Adam({"lr": LEARNING_RATE, "betas": (BETA_1, 0.999)})

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    guide = config_enumerate(ss_vae.guide, ENUM_DISCRETE, expand=True)
    Elbo = JitTraceEnum_ELBO if USE_JIT else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    if AUX_LOSS:
        elbo = JitTrace_ELBO() if USE_JIT else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify,
                       ss_vae.guide_classify,
                       optimizer,
                       loss=elbo)
        losses.append(loss_aux)

    # Bind the name before entering the try block so that the finally clause
    # never touches an unbound variable when open() itself fails.
    logger = None
    try:
        # setup the logger if a filename is provided
        if LOGFILE:
            logger = open(LOGFILE, "w")

        # number of supervised / unsupervised examples
        sup_num = len(wust_train_ds)
        unsup_num = len(weak_train_ds)

        # ratio of unsupervised to supervised examples, handed to
        # run_inference_for_epoch as the supervised-batch interval
        periodic_interval_batches = unsup_num // sup_num

        # There is no separate validation set here: the "best validation
        # accuracy" below is really the best test accuracy seen so far.
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        for i in range(0, NUM_EPOCHS):

            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(
                wust_train_dl, weak_train_dl, losses,
                periodic_interval_batches)

            # average epoch losses i.e. losses per example, rendered for the log
            str_loss_sup = " ".join(str(v / sup_num) for v in epoch_losses_sup)
            str_loss_unsup = " ".join(str(v / unsup_num) for v in epoch_losses_unsup)

            str_print = f"{i} epoch: avg losses {str_loss_sup} {str_loss_unsup}"

            test_accuracy = get_accuracy(wust_test_dl, ss_vae.classifier,
                                         BATCH_SIZE)
            str_print += " test accuracy {}".format(test_accuracy)

            if best_valid_acc < test_accuracy:
                best_valid_acc = test_accuracy
                corresponding_test_acc = test_accuracy

            print_and_log(logger, str_print)

        final_test_accuracy = get_accuracy(wust_test_dl, ss_vae.classifier,
                                           BATCH_SIZE)
        print_and_log(
            logger,
            f"best validation accuracy {best_valid_acc} corresponding testing accuracy {corresponding_test_acc} "
            f"last testing accuracy {final_test_accuracy}")

    finally:
        # close the logger file object if we opened it earlier
        if logger is not None:
            logger.close()
Пример #4
0
def main():
    """
    Run inference for the text SS-VAE with hard-coded hyper-parameters.

    Tries to resume from ``pyro_param_store.store`` and ``ss_vae_model.pth``, trains
    for a fixed number of epochs while logging to ``./tmp.log``, checkpoints after
    every epoch, writes generated sentences to CSV every tenth epoch and finally
    saves the per-epoch average losses to ``avg_loss_sup.npy`` / ``avg_loss_unsup.npy``.

    :return: None
    """
    pyro.set_rng_seed(12345)
    cuda = True
    ss_vae = TextSSVAE(embed_dim=300,
                       z_dim=300,
                       kernels=[3, 4, 5],
                       filters=[100, 100, 100],
                       hidden_size=300,
                       num_rnn_layers=1,
                       config_enum="parallel",
                       use_cuda=cuda,
                       aux_loss_multiplier=46)

    ss_vae = ss_vae.cuda()

    # Resuming is best-effort: a missing or unreadable checkpoint simply means
    # training starts from scratch.
    try:
        pyro.get_param_store().load('pyro_param_store.store')
        print(
            'successfully loaded param store, remove file from directory if undesired'
        )
    except Exception:
        print("failed to load param store, starting over")

    try:
        ss_vae.load_state_dict(torch.load('ss_vae_model.pth'))
        print(
            'successfully loaded model parameters, remove file from directory if undesired'
        )
    except Exception:
        print("failed to load model parameters")

    # setup the optimizer
    adam_params = {"lr": 1e-4, "betas": (0.9, 0.999), "weight_decay": 0.01}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    jit = False
    guide = config_enumerate(ss_vae.guide, "parallel", expand=True)
    elbo = (JitTraceEnum_ELBO if jit else TraceEnum_ELBO)()

    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    aux_loss = True
    if aux_loss:
        elbo = JitTrace_ELBO() if jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify,
                       ss_vae.guide_classify,
                       optimizer,
                       loss=elbo)
        losses.append(loss_aux)

    batch_size = 32
    valid_num = 100
    train_data_size = 3409
    sup_num = 1163
    num_epochs = 200
    log_path = './tmp.log'

    # The context manager closes the log file on every exit path, including
    # the one where open() itself raises (nothing to close in that case).
    with open(log_path, "w") as logger:
        # NOTE(review): `sup_num=valid_num` (100) is passed here although the
        # averages below divide by sup_num (1163) -- confirm which is intended.
        data_loaders = setup_data_loaders(IMDBCached,
                                          cuda,
                                          batch_size=batch_size,
                                          sup_num=valid_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
        # until we have traversed through the all supervised batches
        periodic_interval_batches = int(train_data_size / (1.0 * sup_num))

        # number of unsupervised examples
        unsup_num = train_data_size - sup_num

        # best validation accuracy seen across epochs and the test accuracy
        # measured in that same epoch
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        sup_loss_log = []
        unsup_loss_log = []

        # run inference for a certain number of epochs
        for i in range(0, num_epochs):
            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # compute average epoch losses i.e. losses per example.
            # These must be real lists: a lazy map() would be exhausted by the
            # join below, leaving nothing in the history that is saved at the end.
            avg_epoch_losses_sup = [v / sup_num for v in epoch_losses_sup]
            avg_epoch_losses_unsup = [v / unsup_num for v in epoch_losses_unsup]

            sup_loss_log.append(avg_epoch_losses_sup)
            unsup_loss_log.append(avg_epoch_losses_unsup)

            # store the loss and validation/testing accuracies in the logfile
            str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
            str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))

            str_print = "{} epoch: avg losses {}".format(
                i, "{} {}".format(str_loss_sup, str_loss_unsup))

            # evaluate in eval mode, then switch back to training mode
            ss_vae.eval()
            validation_accuracy = get_accuracy(data_loaders["valid"],
                                               ss_vae.classifier, batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"],
                                         ss_vae.classifier, batch_size)
            str_print += " test accuracy {}".format(test_accuracy)
            ss_vae.train()

            # checkpoint after every epoch
            torch.save(ss_vae.state_dict(), 'ss_vae_model.pth')
            pyro.get_param_store().save('pyro_param_store.store')

            # update the best validation accuracy and the corresponding testing accuracy
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            # every tenth epoch: generate sentences for both sentiments, once
            # with the plain model and once with conditioned generation
            if i % 10 == 0:
                neg_sentences, neg_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.model,
                    ss_vae.w2v_model,
                    sentiment=0)
                pos_sentences, pos_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.model,
                    ss_vae.w2v_model,
                    sentiment=1)
                str_print += " neg_bleu {}".format(neg_bleu)
                str_print += " pos_bleu {}".format(pos_bleu)
                pd.DataFrame.from_dict(pos_sentences).to_csv(
                    'positive_sentences.csv', encoding='utf-8')
                pd.DataFrame.from_dict(neg_sentences).to_csv(
                    'negative_sentences.csv', encoding='utf-8')

                cond_neg_sentences, cond_neg_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.conditioned_generation,
                    ss_vae.w2v_model,
                    sentiment=0)
                cond_pos_sentences, cond_pos_bleu = generateSentences(
                    data_loaders["test"],
                    ss_vae.conditioned_generation,
                    ss_vae.w2v_model,
                    sentiment=1)
                pd.DataFrame.from_dict(cond_pos_sentences).to_csv(
                    'cond_positive_sentences.csv', encoding='utf-8')
                pd.DataFrame.from_dict(cond_neg_sentences).to_csv(
                    'cond_negative_sentences.csv', encoding='utf-8')
                str_print += " cond_neg_bleu {}".format(cond_neg_bleu)
                str_print += " cond_pos_bleu {}".format(cond_pos_bleu)

            print_and_log(logger, str_print)

        # one row per epoch, one column per loss
        np.save("avg_loss_sup", np.asarray(sup_loss_log))
        np.save("avg_loss_unsup", np.asarray(unsup_loss_log))
        ss_vae.eval()
        final_test_accuracy = get_accuracy(data_loaders["test"],
                                           ss_vae.classifier, batch_size)
        print_and_log(
            logger,
            "best validation accuracy {} corresponding testing accuracy {} "
            "last testing accuracy {}".format(best_valid_acc,
                                              corresponding_test_acc,
                                              final_test_accuracy))
Пример #5
0
def main(args):
    """
    Run inference for CVAE.

    :param args: arguments for CVAE. The attributes read here are ``seed``, ``cuda``,
        ``z_dim``, ``hidden_dimension``, ``learning_rate``, ``beta_1``, ``enum_discrete``,
        ``logfile``, ``batch_size`` and ``num_epochs``.
    :return: None
    """
    if args.seed is not None:
        set_seed(args.seed, args.cuda)

    cvae = CVAE(z_dim=args.z_dim,
                hidden_dim=args.hidden_dimension,
                use_cuda=args.cuda)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = ClippedAdam(adam_params)

    # set up the loss for inference.
    # NOTE(review): newer Pyro releases call this argument `max_plate_nesting`;
    # confirm against the installed Pyro version before renaming.
    guide = config_enumerate(cvae.guide, args.enum_discrete)
    loss = SVI(cvae.model, guide, optimizer, loss=TraceEnum_ELBO(max_iarange_nesting=1))

    # Bind the name before entering the try block so that the finally clause
    # never touches an unbound variable when open() itself fails.
    logger = None
    try:
        # setup the logger if a filename is provided
        if args.logfile:
            logger = open(args.logfile, "w")

        data_loaders = setup_data_loaders(NHANES, args.cuda, args.batch_size)

        # best validation accuracy seen across epochs and the test accuracy
        # measured in that same epoch
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        for i in range(0, args.num_epochs):

            # get the losses for an epoch
            epoch_losses = run_inference_for_epoch(data_loaders, loss)

            # compute average epoch losses i.e. losses per example
            avg_epoch_losses = epoch_losses / NHANES.train_data_size

            # store the losses in the logfile
            str_print = "{} epoch: avg loss {}".format(i, str(avg_epoch_losses))

            validation_accuracy = get_accuracy(data_loaders["valid"], cvae.sim_measurements, args.batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"], cvae.sim_measurements, args.batch_size)
            str_print += " test accuracy {}".format(test_accuracy)

            # update the best validation accuracy and the corresponding testing accuracy
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            print_and_log(logger, str_print)

        final_test_accuracy = get_accuracy(data_loaders["test"], cvae.sim_measurements, args.batch_size)

        print_and_log(logger, "best validation accuracy {} corresponding testing accuracy {} "
                              "last testing accuracy {}".format(best_valid_acc, corresponding_test_acc,
                                                                final_test_accuracy))

    finally:
        # close the logger file object if we opened it earlier
        if logger is not None:
            logger.close()
Пример #6
0
def main():
    """
    Run inference for SS-VAE on MNIST with hard-coded hyper-parameters.

    Trains for a fixed number of epochs and logs the average losses and the
    validation/test accuracies of every epoch to ``./tmp.log``.

    :return: None
    """
    pyro.set_rng_seed(12345)
    cuda = True
    # NOTE(review): the original author was unsure about the "sequential"
    # enumeration setting -- confirm it is the intended strategy.
    ss_vae = SSVAE(
        z_dim=50,
        hidden_layers=[500],
        use_cuda=cuda,
        config_enum="sequential",
        aux_loss_multiplier=46)

    # setup the optimizer
    adam_params = {"lr": 0.00042, "betas": (0.9, 0.999)}
    optimizer = Adam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    jit = False
    guide = config_enumerate(ss_vae.guide, "sequential")
    elbo = (JitTraceEnum_ELBO if jit else TraceEnum_ELBO)()
    loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: whether to use the auxiliary loss from NIPS 14 paper (Kingma et al)
    aux_loss = True
    if aux_loss:
        elbo = JitTrace_ELBO() if jit else Trace_ELBO()
        loss_aux = SVI(ss_vae.model_classify,
                       ss_vae.guide_classify,
                       optimizer,
                       loss=elbo)
        losses.append(loss_aux)

    # batch_size: number of images (and labels) to be considered in a batch
    batch_size = 200
    sup_num = 3000
    num_epochs = 10
    log_path = './tmp.log'

    # The context manager closes the log file on every exit path, including
    # the one where open() itself raises (nothing to close in that case).
    with open(log_path, "w") as logger:
        data_loaders = setup_data_loaders(MNISTCached,
                                          cuda,
                                          batch_size,
                                          sup_num=sup_num)

        # how often would a supervised batch be encountered during inference
        # e.g. if sup_num is 3000, we would have every 16th = int(50000/3000) batch supervised
        # until we have traversed through the all supervised batches
        periodic_interval_batches = int(MNISTCached.train_data_size /
                                        (1.0 * sup_num))

        # number of unsupervised examples
        unsup_num = MNISTCached.train_data_size - sup_num

        # best validation accuracy seen across epochs and the test accuracy
        # measured in that same epoch
        best_valid_acc, corresponding_test_acc = 0.0, 0.0

        # run inference for a certain number of epochs
        for i in range(0, num_epochs):

            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = \
                run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)

            # average epoch losses i.e. losses per example, rendered for the log
            str_loss_sup = " ".join(str(v / sup_num) for v in epoch_losses_sup)
            str_loss_unsup = " ".join(str(v / unsup_num) for v in epoch_losses_unsup)

            str_print = "{} epoch: avg losses {}".format(
                i, "{} {}".format(str_loss_sup, str_loss_unsup))

            validation_accuracy = get_accuracy(data_loaders["valid"],
                                               ss_vae.classifier, batch_size)
            str_print += " validation accuracy {}".format(validation_accuracy)

            # this test accuracy is only for logging, this is not used
            # to make any decisions during training
            test_accuracy = get_accuracy(data_loaders["test"],
                                         ss_vae.classifier, batch_size)
            str_print += " test accuracy {}".format(test_accuracy)

            # update the best validation accuracy and the corresponding testing accuracy
            if best_valid_acc < validation_accuracy:
                best_valid_acc = validation_accuracy
                corresponding_test_acc = test_accuracy

            print_and_log(logger, str_print)

        final_test_accuracy = get_accuracy(data_loaders["test"],
                                           ss_vae.classifier, batch_size)
        print_and_log(
            logger,
            "best validation accuracy {} corresponding testing accuracy {} "
            "last testing accuracy {}".format(best_valid_acc,
                                              corresponding_test_acc,
                                              final_test_accuracy))
Пример #7
0
def training(args, rel_embeddings, word_embeddings):
    """
    Train a ``SimpleGenerator`` with Pyro SVI and return it.

    Builds the generator and its SVI loss(es), generates data through
    ``RealModel``, runs ``args.num_epochs`` calls to ``train_epoch`` while
    logging the average losses, and finally saves the generator's
    ``state_dict`` to ``./models/test_generator_state_dict.pth``.

    The log is written to
    ``./logs/<experiment_type>/<experiment_name>.log``; neither that
    directory nor ``./models/`` is created here, so both must already exist.

    :param args: parsed options; reads ``seed``, ``cuda``,
        ``decoder_hidden_dim``, ``encoder_hidden_dim``, ``batch_size``,
        ``learning_rate``, ``beta_1``, ``enumerate``, ``enum_discrete``
        (only when ``enumerate`` is set), ``jit``, ``aux_loss``,
        ``experiment_type``, ``experiment_name`` and ``num_epochs``
    :param rel_embeddings: relation embeddings handed to ``RealModel``
    :param word_embeddings: word embeddings; handed to ``RealModel`` and
        wrapped in ``torch.FloatTensor`` for ``Config``
    :return: the trained ``SimpleGenerator``
    """
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    # CUDA for PyTorch: only switch devices when it is both requested and available
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.set_device(0)
        print("using gpu acceleration")

    print("Generating Config")
    config = Config(
        word_embeddings=torch.FloatTensor(word_embeddings),
        decoder_hidden_dim=args.decoder_hidden_dim,
        num_relations=7,
        encoder_hidden_dim=args.encoder_hidden_dim,
        num_predicates=1000,
        batch_size=args.batch_size
    )

    # initialize the generator model
    generator = SimpleGenerator(config)

    # setup the optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta_1, 0.999)}
    optimizer = ClippedAdam(adam_params)

    # set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
    # by enumerating each class label for the sampled discrete categorical distribution in the model
    if args.enumerate:
        guide = config_enumerate(generator.guide, args.enum_discrete, expand=True)
    else:
        guide = generator.guide
    elbo = (JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO)(max_plate_nesting=1)
    loss_basic = SVI(generator.model, guide, optimizer, loss=elbo)

    # build a list of all losses considered
    losses = [loss_basic]

    # aux_loss: optional second SVI objective over model_identify / guide_identify
    if args.aux_loss:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        loss_aux = SVI(generator.model_identify, generator.guide_identify, optimizer, loss=elbo)
        losses.append(loss_aux)

    # prepare data: fixed positional split into supervised / unsupervised / eval / test
    real_model = RealModel(rel_embeddings, word_embeddings)
    data = real_model.generate_data()
    sup_train_set = data[:100]
    unsup_train_set = data[100:700]
    eval_set = data[700:900]
    test_set = data[900:]

    data_loaders = setup_data_loaders(sup_train_set,
                                      unsup_train_set,
                                      eval_set,
                                      test_set,
                                      batch_size=args.batch_size)

    # sizes are loop-invariant, so compute them once
    sup_size = len(sup_train_set)
    unsup_size = len(unsup_train_set)
    num_train = sup_size + unsup_size

    # how often a supervised batch would be encountered during inference:
    # the ratio of all training examples to supervised ones, e.g. with a full
    # 100 / 600 split this is int(700 / 100) = 7
    periodic_interval_batches = int(1.0 * num_train / sup_size)

    # the log file is always written; the context manager guarantees it is
    # flushed and closed even if an epoch raises
    log_fn = "./logs/" + args.experiment_type + '/' + args.experiment_name + '.log'
    with open(log_fn, "w") as logger:
        # run inference for a certain number of epochs
        for epoch in tqdm(range(args.num_epochs)):
            # get the losses for an epoch
            epoch_losses_sup, epoch_losses_unsup = train_epoch(
                data_loaders=data_loaders,
                models=losses,
                periodic_interval_batches=periodic_interval_batches)

            # average epoch losses, i.e. losses per example, one value per loss
            str_loss_sup = " ".join(str(v / sup_size) for v in epoch_losses_sup)
            str_loss_unsup = " ".join(str(v / unsup_size) for v in epoch_losses_unsup)

            # store the losses in the logfile
            str_print = "{} epoch: avg losses {} {}".format(
                epoch, str_loss_sup, str_loss_unsup)
            print_and_log(logger, str_print)

    # save trained models
    torch.save(generator.state_dict(), './models/test_generator_state_dict.pth')
    return generator
Пример #8
0
# Running-best trackers, all starting from zero.
# NOTE(review): judging by the names, the first triple is presumably the best
# validation kappa with the test kappa and epoch recorded alongside it, and the
# second triple the mirror image for the best test kappa -- the update logic
# is not visible here; confirm.
best_valid_kappa = 0.0
cor_test_kappa = 0.0
cor_epoch_1 = 0

best_test_kappa = 0.0
cor_valid_kappa = 0.0
cor_epoch_2 = 0

for epoch in range(num_epochs):

    ssvae.eval()
    epoch_losses_sup, epoch_losses_aux, epoch_losses_unsup = get_evaluate_for_epoch(
        train_dataloader, test_dataloader, loss_basic, loss_aux, use_cuda)
    epoch_losses_sup = epoch_losses_sup / len(train_dataset)
    epoch_losses_aux = epoch_losses_aux / len(train_dataset)
    epoch_losses_unsup = epoch_losses_unsup / len(test_dataset)
    ratio = np.abs(epoch_losses_sup / epoch_losses_aux)
    str_print = 'Epoch {0} (eval mode): epoch_losses_sup, {1:.3f}; epoch_losses_aux, {2:.3f}; ' \
                'epoch_losses_unsup, {3:.3f}; ratio, {4:.4f}'\
        .format(epoch, epoch_losses_sup, epoch_losses_aux, epoch_losses_unsup, ratio)
    print_and_log(logger, str_print)

    writer.add_scalar('Eval mode/epoch_losses_sup', epoch_losses_sup, epoch)
    writer.add_scalar('Eval mode/epoch_losses_aux', epoch_losses_aux, epoch)
    writer.add_scalar('Eval mode/epoch_losses_unsup', epoch_losses_unsup,
                      epoch)
    writer.add_scalar('Eval mode/ratio', ratio, epoch)
    '''Train'''
    ssvae.train()
    if epoch_losses_aux < 0.0012 and ratio > 400000:
        epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch_ncls(
            train_dataloader, test_dataloader, loss_basic, use_cuda)
    # elif ratio>10000:
    #     epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(train_dataloader, test_dataloader,
    #                                                                    loss, use_cuda, ratio * factor)
    else: