gen_init=gen_init,
                        theta_y=theta_y,
                        disc_init=netD_state,
                        experiment=e)

    e.log.info(model)

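    # Train the model, extract the learned theta estimates, and checkpoint it.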
    model.train()
    model.get_theta()
    model.save()

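    # Persist the diagnostic traces collected during training (thetas, deltas,
    # weights, gradient norms, losses, tolerances) for offline analysis.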
    with open(e.experiment_dir + "/bayes_all+thetas+weights.pkl", "wb") as fp:
        pickle.dump([
            theta_y, model.all_theta, model.all_delta, model.all_weights,
            model.all_grad_norm, model.all_loss, model.all_tolerance
        ], fp)


if __name__ == '__main__':

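    # Parse command-line arguments using the project's shared base parser.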
    args = config.get_base_parser().parse_args()

    with train_helper.experiment(args, args.save_prefix) as e:
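        # Fix the NumPy and PyTorch seeds so runs are reproducible.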
        np.random.seed(e.config.random_seed)
        torch.manual_seed(e.config.random_seed)

        e.log.info("*" * 25 + " ARGS " + "*" * 25)
        e.log.info(args)
        e.log.info("*" * 25 + " ARGS " + "*" * 25)

        run(e)
Example #2
# Cosine similarity along the last dimension; the original header was
# truncated here, so the function name and the `prod` term are assumed.
def cosine_similarity(v1, v2):
    prod = (v1 * v2).sum(-1)
    v1_norm = (v1 ** 2).sum(-1) ** 0.5
    v2_norm = (v2 ** 2).sum(-1) ** 0.5
    return prod / (v1_norm * v2_norm)
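
# Usage sketch (hypothetical tensors y1, y2 of shape (batch, dim)):
#   sims = cosine_similarity(y1, y2)  # -> (batch,) per-row similarities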


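# Load the checkpoint onto CPU, regardless of the device it was saved from.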
save_dict = torch.load(args.save_file,
                       map_location=lambda storage, loc: storage)

config = save_dict['config']
checkpoint = save_dict['state_dict']
config.debug = True

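# Load the pretrained embedding matrix W (may be None) and the vocabulary.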
with open(args.vocab_file, "rb") as fp:
    W, vocab = pickle.load(fp)

with train_helper.experiment(config, config.save_prefix) as e:
    e.log.info("vocab loaded from: {}".format(args.vocab_file))
    model = models.vgvae(vocab_size=len(vocab),
                         embed_dim=e.config.edim if W is None else W.shape[1],
                         embed_init=W,
                         experiment=e)
    model.eval()
    model.load(checkpointed_state_dict=checkpoint)
    e.log.info(model)

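    # Encode raw sentences into the model's two latent spaces (y and z).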
    def encode(d):
        global vocab, batch_size
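        # Map each whitespace-tokenized word to its vocab id (0 if unseen).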
        new_d = [[vocab.get(w, 0) for w in s.split(" ")] for s in d]
        all_y_vecs = []
        all_z_vecs = []