Example #1
    def __init__(self, decode_from, params, cuda=False):
        self.decode_from = decode_from
        self.params = params
        params.enc_nh = params.dec_nh  # the decoder assumes enc_nh == dec_nh, so keep them in sync

        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        # weight/embedding initializers passed to the encoder and decoder below
        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)

        params.device = self.device = torch.device("cuda" if cuda else "cpu")

        self.encoder = LSTMEncoder(params, self.vocab_size, model_init,
                                   emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)

        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # map_location lets a checkpoint trained on GPU load on either device
        self.vae.load_state_dict(
            torch.load(self.decode_from, map_location=self.device))
Example #2
def main():
    # Load MNIST image dataset
    mnist_train_data = datasets.MNIST(
        '/home/ajays/Downloads/', download=True, transform=transforms.ToTensor()
    )
    mnist_test_data = datasets.MNIST(
        '/home/ajays/Downloads/', train=False, download=True
    )

    # batch_size (like LOAD_PATH below) is assumed to be a module-level constant
    train_loader = torch.utils.data.DataLoader(
        mnist_train_data, batch_size=batch_size, shuffle=True
    )

    # Instantiation
    vae = VAE(n_inputs=32)

    # *********************
    # IMAGE VAE TRAINING
    # *********************
    # plot before training
    # o_before, mu, logvar = vae(mnist_train_data[0][0].reshape((1,1,28,28)))
    # plt.imshow(o_before.detach().numpy().reshape((28,28)))
    # plt.show()

    # load pre-trained weights instead of training from scratch
    vae.load_state_dict(torch.load(LOAD_PATH))
    # vae = train_image_vae(vae, train_loader)

    # After training
    # o_after, mu, logvar = vae(example[0].reshape((1,1,28,28)))
    o_after = vae.decode(torch.randn(128))  # decode a latent sampled from the prior
    plt.imshow(o_after.detach().numpy().reshape((28,28)))
    plt.show()
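
A quick extension of the sampling code above: the same vae.decode call can render a grid of prior samples rather than a single digit. This is a minimal sketch, assuming (as the example implies) that vae.decode accepts latents of size 128 and returns flattened 28x28 images:

def plot_prior_samples(vae, rows=4, cols=4, latent_dim=128):
    # draw latents from the standard-normal prior and decode them
    with torch.no_grad():
        z = torch.randn(rows * cols, latent_dim)
        imgs = vae.decode(z).reshape(-1, 28, 28)
    fig, axes = plt.subplots(rows, cols, figsize=(6, 6))
    for ax, img in zip(axes.flat, imgs):
        ax.imshow(img.numpy(), cmap="gray")
        ax.axis("off")
    plt.show()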
Example #3
def val_test(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    # if torch.cuda.device_count() > 1 and args.device == "cuda":
    # 	model = torch.nn.DataParallel(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {}

    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res,
                                    args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    stop_patience = args.stop_patience  # read but never used in this snippet; see the sketch below
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators, True)
        # if args.weights == "init" and epoch==1:
        # 	epoch+=1
        # 	break

        # print(z.shape)
        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

    print(val_loss_dict)
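
stop_patience is read above but never used, and best_loss is never updated; presumably they fed an early-stopping check that was edited out. A minimal, self-contained sketch of that bookkeeping (hypothetical names, not the original train_util code):

def should_stop(val_loss, state, patience):
    # returns True once val_loss has failed to improve `patience` epochs in a row
    if val_loss < state["best_loss"]:
        state["best_loss"] = val_loss
        state["bad_epochs"] = 0
    else:
        state["bad_epochs"] += 1
    return state["bad_epochs"] >= patience

state = {"best_loss": float("inf"), "bad_epochs": 0}

Calling should_stop(val_loss_dict["recons_loss"], state, stop_patience) at the end of each epoch would make both variables meaningful.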
Example #4
def get_model(model, hidden_size, k, num_channels, resolution, num_classes):

    if model == "vqvae":
        # hs_256 checkpoints live at the top level; every other size is keyed by hs_{size}
        if hidden_size == 256:
            CKPT_DIR = "models/imagenet/best.pt"
        else:
            CKPT_DIR = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
    elif model == "vae":
        CKPT_DIR = f"models/imagenet_hs_128_{hidden_size}_vae.pt"
        model = VAE(num_channels, hidden_size, hidden_size)

    elif model == "aae":
        CKPT_DIR = f"models/aae/imagenet_hs_32_{hidden_size}/best.pt"
        model = AAE(32, num_channels, hidden_size)
    else:
        model = None

    imgnetclassifier = ImgnetClassifier(model, hidden_size, resolution)

    if hidden_size > 0:
        ckpt = torch.load(CKPT_DIR)
        model.load_state_dict(ckpt)

    return imgnetclassifier
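
Note that the fall-through branch above leaves model as None and CKPT_DIR unset, so an unknown model name crashes inside load_state_dict. A defensive variant of the loading tail (a sketch, not the original code):

if model is None:
    raise ValueError("unknown model type")
if hidden_size > 0:
    ckpt = torch.load(CKPT_DIR, map_location="cpu")  # load to CPU first, move later
    model.load_state_dict(ckpt)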
Example #5
def get_model(model,
              hidden_size,
              num_channels,
              resolution,
              enc_type,
              dec_type,
              num_classes,
              k=512):

    model_name = model  # keep the string; `model` is rebound to a network below
    CKPT_DIR = f"models/{model_name}_{args.recons_loss}/{args.train_dataset}/depth_{enc_type}_{dec_type}_hs_{args.img_res}_{hidden_size}/best.pt"
    if model_name == "vqvae":
        model = VectorQuantizedVAE(num_channels, hidden_size, k, enc_type,
                                   dec_type)
    elif model_name == "vae":
        model = VAE(num_channels, hidden_size, enc_type, dec_type)
    elif model_name == "acai":
        model = ACAI(resolution, num_channels, hidden_size, enc_type, dec_type)
    else:
        model = None

    if model_name != "supervised":
        imgclassifier = ImgClassifier(model, hidden_size, resolution,
                                      num_classes)
    else:
        imgclassifier = SupervisedImgClassifier(hidden_size, enc_type,
                                                resolution, num_classes)

    if hidden_size > 0 and model is not None:
        ckpt = torch.load(CKPT_DIR)
        model.load_state_dict(ckpt["model"])

    return imgclassifier
Example #6
def get_vae_recons(loader, hidden_size=256):

    model = VAE(3, hidden_size, hidden_size).to(DEVICE)

    ckpt = torch.load("./models/imagenet_hs_128_256_vae.pt")
    model.load_state_dict(ckpt)
    args = type('', (), {})()
    args.device = DEVICE
    gen_img, _ = next(iter(loader))
    # grid = make_grid(gen_img.cpu(), nrow=8)
    # torchvision.utils.save_image(grid, "hs_{}_recons.png".format(hidden_size))
    #exit()

    # `vae` is assumed to be an imported module exposing generate_samples
    reconstruction = vae.generate_samples(gen_img, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8)

    return grid
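
The type('', (), {})() line above builds a throwaway object just to hang attributes on; types.SimpleNamespace from the standard library does the same thing more readably:

from types import SimpleNamespace

args = SimpleNamespace(device=DEVICE)  # equivalent to: args = type('', (), {})(); args.device = DEVICE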
Example #7
def get_vae(model_encoder,
            model_decoder,
            tokenizer_encoder,
            tokenizer_decoder,
            beta=1):
    ArgsObj = namedtuple("Args", ["latent_size", "device", "fb_mode", "beta"])
    args = ArgsObj(latent_size=LATENT_SIZE_LARGE,
                   device=get_device(),
                   fb_mode=0,
                   beta=beta)

    checkpoint_full_dir = os.path.join(OUTPUT_DIR, "checkpoint-full-31250")
    if not torch.cuda.is_available():
        checkpoint = torch.load(os.path.join(checkpoint_full_dir,
                                             "training.bin"),
                                map_location="cpu")
    else:
        checkpoint = torch.load(
            os.path.join(checkpoint_full_dir, "training.bin"))

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint["model_state_dict"])
    # logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)
    return model_vae
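
The cuda/cpu branching around torch.load above can be collapsed: map_location=None is torch.load's default behavior, so a single call covers both cases. A drop-in sketch for that branch:

checkpoint = torch.load(
    os.path.join(checkpoint_full_dir, "training.bin"),
    map_location=None if torch.cuda.is_available() else "cpu")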
Example #8
def create_model(args, vocab):
    # build initializers
    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    # build encoder
    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, args.vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh  # keep encoder/decoder hidden sizes in sync for the decoder
    else:
        raise ValueError("the specified encoder type is not supported")
    
    # build decoder
    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(args.device)
    
    return vae
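
A hypothetical invocation of create_model, with field names taken from what the function itself reads; the LSTM classes likely consume more fields (embedding size, dropout, ...), and vocab would come from MonoTextData as in Example #1:

from types import SimpleNamespace

args = SimpleNamespace(enc_type='lstm', vocab_size=len(vocab),
                       enc_nh=512, dec_nh=512,
                       device=torch.device('cpu'))
vae = create_model(args, vocab)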
Example #9
def get_model(ae_type, hidden_size, k, num_channels):
	if ae_type == "vqvae":
		# hs_256 checkpoints live at the top level; other sizes are keyed by hs_{size}
		if hidden_size == 256:
			CKPT_DIR = "models/imagenet/best.pt"
		else:
			CKPT_DIR = "models/imagenet/hs_{}/best.pt".format(hidden_size)
		model = VectorQuantizedVAE(num_channels, hidden_size, k)
		imgnetclassifier = ImgnetClassifier(model, hidden_size)
	elif ae_type == "vae":
		CKPT_DIR = "models/imagenet_vae.pt"
		model = VAE(num_channels, hidden_size, 4096)
		imgnetclassifier = ImgnetClassifier(model, 4)
	else:
		raise ValueError("unsupported ae_type: {}".format(ae_type))

	ckpt = torch.load(CKPT_DIR)
	model.load_state_dict(ckpt)

	return imgnetclassifier
Example #10
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default=None, type=str, help="The dataset.")

    ## Other parameters
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--ExpName", default="", type=str,
                        help="The experiment name used in Azure Table.")

    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--use_deterministic_connect", action='store_true',
                        help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
    parser.add_argument("--use_beta_schedule", action='store_true', help="Use cyclical beta schedule for auto-encoders.")

    ## Objective functions
    parser.add_argument("--mlm", action='store_true',
                        help="Train with masked-language modeling loss instead of language modeling.")
    parser.add_argument("--mlm_probability", type=float, default=0.15,
                        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--beta", type=float, default=1.0,
                        help="The weighting hyper-parameter of the KL term in VAE")


    parser.add_argument("--cache_dir", default="", type=str,
                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")


    # Training Schedules
    parser.add_argument("--ratio_increase", default=0.25, type=float,
                        help="Learning schedule, the percentage for the annealing stage.") 
    parser.add_argument("--ratio_zero", default=0.25, type=float,
                        help="Learning schedule, the percentage for the pure auto-encoding stage.")     
    parser.add_argument("--fb_mode", default=0, type=int,
                        help="free bit training mode.")   
    parser.add_argument("--dim_target_kl", default=3.0, type=float,
                        help="dim_target_kl free bit training mode.")                            
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")

    ## IO: Logging and Saving
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training 
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")

    parser.add_argument('--world-size', default=ompi_size(), type=int, help='number of distributed processes')
    parser.add_argument('--dist-url', default='tcp://' + get_master_ip() + ':23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--port', type=str, default='51115', help="Port")

    args = parser.parse_args()

    args.dist_url = 'tcp://' + get_master_ip() + ':' + args.port

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    rank_node = ompi_rank()
    args.distributed = args.world_size > 1
    logger.info("Rank {} distributed: {}".format(rank_node, args.distributed))

    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
                         "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
                         "or remove the --do_eval argument.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    if args.distributed:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=ompi_rank(),
            group_name='mtorch')
        logger.info("World Size is {}, Backend is {}, Init Method is {}, rank is {}".format(args.world_size, args.dist_backend, args.dist_url, ompi_rank()))

    gpus = list(gpu_indices())
    args.n_gpu = len(gpus)
    args.local_rank = ompi_rank() #gpus[0]
    torch.cuda.set_device(gpus[0])
    device = torch.device("cuda", gpus[0])

    args.device = device
    logger.info('Rank {}, gpus: {}, get_rank: {}'.format(rank_node, gpus, torch.distributed.get_rank()))
    logger.info(f'Local rank is {args.local_rank}, {rank_node}')

    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_'  + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size) 
    if ompi_rank() == 0:
        try:
            ts.create_table(table_name)
        except:  # ignore failures (e.g., the table already exists)
            pass


    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    #if args.local_rank not in [-1, 0]: torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

    ## Encoder 
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
    encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
    model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
    # model_encoder.to(args.device)

    ## Decoder 
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
    decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
    setattr(decoder_config, "latent_size", args.latent_size)
    model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size)
    
    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    #model_decoder.to(args.device)

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device) #
    #model_vae.cuda()

    # Distributed training (should be after apex fp16 initialization)
    if args.distributed:
        # model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=gpus, output_device=args.local_rank, find_unused_parameters=True)
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=gpus)
    elif args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae)#.to(args.device)

    # on_gpu = next(model_vae.parameters()).is_cuda

    #if args.local_rank == 0: torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    global_step=0
    if args.do_train:
        #if args.local_rank not in [-1, 0]: torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)

        #if args.local_rank == 0: torch.distributed.barrier()

        global_step, tr_loss = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
        logger.info("Rank %d, global_step = %s, average loss = %s", ompi_rank(), global_step, tr_loss)


    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        # Save model checkpoint
        output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
        output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
        if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
            os.makedirs(output_encoder_dir)
        if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
            os.makedirs(output_decoder_dir)

        logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
        logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`

        model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder  # Take care of distributed/parallel training
        model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder  # Take care of distributed/parallel training

        # Good practice: save your training arguments together with the trained model
        if args.use_philly:
            save_solid = False
            while not save_solid:
                try:
                    model_encoder_to_save.save_pretrained(output_encoder_dir)
                    torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))
                    save_solid = True
                except:
                    pass
        else:
            model_encoder_to_save.save_pretrained(output_encoder_dir)
            torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))


        if args.use_philly:
            save_solid = False
            while not save_solid:
                try:
                    model_decoder_to_save.save_pretrained(output_decoder_dir)
                    torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
                    save_solid = True
                except:
                    pass
        else:
            model_decoder_to_save.save_pretrained(output_decoder_dir)
            torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
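
The use_philly branches above retry forever on a bare except, which will spin on any persistent error. A bounded-retry sketch (hypothetical) that keeps the same resilience intent:

import time

for attempt in range(5):
    try:
        model_decoder_to_save.save_pretrained(output_decoder_dir)
        torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))
        break
    except OSError:
        time.sleep(2 ** attempt)  # back off on transient filesystem errors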
Example #11
    def __init__(self, config, vocab, rev_vocab, PAD_token=0):
        super(multiVAE, self).__init__()
        assert rev_vocab['<pad>'] == PAD_token
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.embed_size = config.emb_size
        self.hidden_size = config.n_hidden
        self.bow_size = config.bow_size
        self.rev_vocab = rev_vocab
        self.dropout = config.dropout
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.maxlen = config.maxlen
        self.clip = config.clip
        self.temp = config.temp
        self.full_kl_step = config.full_kl_step
        self.z_size = config.z_size
        self.init_w = config.init_weight
        self.softmax = nn.Softmax(dim=1)
        self.bidirectional = config.bidirectional
        self.lr_ae = config.lr_ae
        self.lr_vae = config.lr_vae

        # if the LSTM is bidirectional, the two directions are concatenated
        self.encoder_output_size = self.hidden_size * (1 + int(self.bidirectional))
        # the title and the first line are concatenated
        self.context_dim = self.encoder_output_size * 2
        self.decoder_input_size = self.z_size

        # build components
        self.layers = nn.ModuleDict()
        self.layers["embedder"] = nn.Embedding(self.vocab_size, self.embed_size, padding_idx=PAD_token)
        # encode the title and each line of the poem; a bidirectional LSTM by
        # default, with the final states of both directions concatenated
        self.layers["seq_encoder"] = Encoder(embedder=self.layers["embedder"], input_size=config.emb_size, hidden_size=config.n_hidden,
                                   bidirectional=self.bidirectional, n_layers=config.n_layers, noise_radius=config.noise_radius)

        # prior networks (one VAE per sentiment: negative / neutral / positive)
        self.layers["neg_vae"] = VAE(target_size=self.encoder_output_size, z_size=self.z_size, dropout=self.dropout, init_weight=self.init_w)
        self.layers["neu_vae"] = VAE(target_size=self.encoder_output_size, z_size=self.z_size, dropout=self.dropout, init_weight=self.init_w)
        self.layers["pos_vae"] = VAE(target_size=self.encoder_output_size, z_size=self.z_size, dropout=self.dropout, init_weight=self.init_w)

        # bag-of-words (BoW) loss projections over the vocabulary
        self.layers["bow_project_pos"] = nn.Sequential(
            nn.Linear(self.decoder_input_size, self.bow_size),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.bow_size, self.vocab_size)
        )
        self.layers["bow_project_neu"] = nn.Sequential(
            nn.Linear(self.decoder_input_size, self.bow_size),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.bow_size, self.vocab_size)
        )
        self.layers["bow_project_neg"] = nn.Sequential(
            nn.Linear(self.decoder_input_size, self.bow_size),
            nn.LeakyReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.bow_size, self.vocab_size)
        )

        # self.layers["decoder"] = Decoder(embedder=self.layers["embedder"], input_size=self.embed_size,
        #                        hidden_size=self.hidden_size,
        #                        vocab_size=self.vocab_size, n_layers=1)

        self.layers["vae_decoder_pos"] = Decoder(embedder=self.layers["embedder"], input_size=self.embed_size,
                           hidden_size=self.hidden_size,
                           vocab_size=self.vocab_size, n_layers=1)
        self.layers["vae_decoder_neu"] =  Decoder(embedder=self.layers["embedder"], input_size=self.embed_size,
                           hidden_size=self.hidden_size,
                           vocab_size=self.vocab_size, n_layers=1)
        self.layers["vae_decoder_neg"] = Decoder(embedder=self.layers["embedder"], input_size=self.embed_size,
                           hidden_size=self.hidden_size,
                           vocab_size=self.vocab_size, n_layers=1)

        self.layers["init_decoder"] = nn.Sequential(
            nn.Linear(self.decoder_input_size + self.context_dim, self.hidden_size),
            nn.BatchNorm1d(self.hidden_size, eps=1e-05, momentum=0.1),
            nn.LeakyReLU()
        )

        self.layers["init_decoder_hidden"] = nn.Sequential(
            nn.Linear(self.decoder_input_size, self.hidden_size),
            nn.BatchNorm1d(self.hidden_size, eps=1e-05, momentum=0.1),
            nn.LeakyReLU()
        )

        self.layers["init_decoder_hidden"].apply(self.init_weights)
        self.layers["bow_project_neg"].apply(self.init_weights)
        self.layers["bow_project_neu"].apply(self.init_weights)
        self.layers["bow_project_pos"].apply(self.init_weights)

        # self.optimizer_AE = optim.AdamW(list(self.layers["embedder"].parameters())
        #                                 + list(self.layers["seq_encoder"].parameters())
        #                                 + list(self.layers["vae_decoder_pos"].parameters())
        #                                 + list(self.layers["vae_decoder_neu"].parameters())
        #                                 + list(self.layers["vae_decoder_neg"].parameters())
        #                                 + list(self.layers["init_decoder"].parameters()), lr=self.lr_ae)
        self.optimizer_AE = {
            'pos': optim.AdamW(list(self.layers["embedder"].parameters())
                               + list(self.layers["init_decoder"].parameters())
                               + list(self.layers["seq_encoder"].parameters())
                               + list(self.layers["vae_decoder_pos"].parameters())
                               + list(self.layers["pos_vae"].parameters())
                               + list(self.layers["bow_project_pos"].parameters()), lr=self.lr_vae),
            'neu': optim.AdamW(list(self.layers["embedder"].parameters())
                               + list(self.layers["init_decoder"].parameters())
                               + list(self.layers["seq_encoder"].parameters())
                               + list(self.layers["vae_decoder_neu"].parameters())
                               + list(self.layers["neu_vae"].parameters())
                               + list(self.layers["bow_project_neu"].parameters()), lr=self.lr_vae),
            'neg': optim.AdamW(list(self.layers["embedder"].parameters())
                               + list(self.layers["init_decoder"].parameters())
                               + list(self.layers["seq_encoder"].parameters())
                               + list(self.layers["vae_decoder_neg"].parameters())
                               + list(self.layers["neg_vae"].parameters())
                               + list(self.layers["bow_project_neg"].parameters()), lr=self.lr_vae)
        }

        self.optimizer_VAE = {
            'pos': optim.AdamW(list(self.layers["vae_decoder_pos"].parameters())
                               + list(self.layers["pos_vae"].parameters())
                               + list(self.layers["bow_project_pos"].parameters()), lr=self.lr_vae),
            'neu': optim.AdamW(list(self.layers["vae_decoder_neu"].parameters())
                               + list(self.layers["neu_vae"].parameters())
                               + list(self.layers["bow_project_neu"].parameters()), lr=self.lr_vae),
            'neg': optim.AdamW(list(self.layers["vae_decoder_neg"].parameters())
                               + list(self.layers["neg_vae"].parameters())
                               + list(self.layers["bow_project_neg"].parameters()), lr=self.lr_vae)
        }

        # self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE, step_size=10, gamma=0.6)

        self.criterion_ce = nn.CrossEntropyLoss()

        self.reconstruct_loss = dict()
Example #12
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help=
        "An input evaluation data file to evaluate the perplexity on (a text file)."
    )
    parser.add_argument("--checkpoint_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The directory where checkpoints are saved.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--dataset",
                        default='Snli',
                        type=str,
                        help="The dataset.")

    ## Variational auto-encoder
    parser.add_argument("--latent_size",
                        default=32,
                        type=int,
                        help="Latent space dimension.")
    parser.add_argument("--total_sents",
                        default=10,
                        type=int,
                        help="Total sentences to test recontruction.")
    parser.add_argument("--num_interpolation_steps",
                        default=10,
                        type=int,
                        help="Total sentences to test recontruction.")
    parser.add_argument("--play_mode",
                        default="interpolation",
                        type=str,
                        help="interpolation or reconstruction.")

    ## Encoder options
    parser.add_argument(
        "--encoder_model_type",
        default="bert",
        type=str,
        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--encoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--encoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--encoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    ## Decoder options
    parser.add_argument(
        "--decoder_model_type",
        default="gpt2",
        type=str,
        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--decoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--decoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--decoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    parser.add_argument("--per_gpu_train_batch_size",
                        default=1,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=1,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gloabl_step_eval',
                        type=int,
                        default=661,
                        help="Evaluate the results at the given global step")

    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "Optional input sequence length before tokenization. The sequence will be dropped if it is longer than max_seq_length"
    )

    # Interact with users
    parser.add_argument("--interact_with_user_input",
                        action='store_true',
                        help="Use user input to interact_with.")
    parser.add_argument("--sent_source", type=str, default="")
    parser.add_argument("--sent_target", type=str, default="")
    parser.add_argument("--sent_input", type=str, default="")
    parser.add_argument("--degree_to_target", type=float, default="1.0")

    ## Variational auto-encoder
    parser.add_argument("--nz",
                        default=32,
                        type=int,
                        help="Latent space dimension.")

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--use_philly",
                        action='store_true',
                        help="Use Philly for computing.")

    args = parser.parse_args()

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()

    global_step = args.gloabl_step_eval

    output_encoder_dir = os.path.join(
        args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(
        args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary that you have fine-tuned
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[
        args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(
        output_encoder_dir, latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name
        if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)

    model_encoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary that you have fine-tuned
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
        args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(
        output_decoder_dir, latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name
        if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_decoder.max_len_single_sentence)

    # Load full model
    output_full_dir = os.path.join(args.checkpoint_dir,
                                   'checkpoint-full-{}'.format(global_step))
    checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'),
                            map_location=torch.device('cpu'))

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(
        len(tokenizer_decoder)
    )  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    # Evaluation
    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint['model_state_dict'])
    logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)

    if args.interact_with_user_input:

        if args.play_mode == 'interpolation':
            if len(args.sent_source) > 0 and len(args.sent_target) > 0:
                result = interpolate(model_vae, tokenizer_encoder,
                                     tokenizer_decoder, args)
            else:
                print('Please check: specify the source and target sentences!')

        if args.play_mode == 'analogy':
            if len(args.sent_source) > 0 and len(args.sent_target) > 0 and len(
                    args.sent_input) > 0:
                result = analogy(model_vae, tokenizer_encoder,
                                 tokenizer_decoder, args)
            else:
                print(
                    'Please check: specify the source, target and input analogy sentences!'
                )

    else:
        result = evaluate_latent_space(args,
                                       model_vae,
                                       tokenizer_encoder,
                                       tokenizer_decoder,
                                       prefix=global_step)
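
For reference, the core of what interpolate() does between the two encoded sentences is a convex combination of their latent codes; a minimal sketch (not the original helper):

def latent_walk(z_src, z_tgt, steps=10):
    # yields latents on the straight line from z_src to z_tgt
    for t in torch.linspace(0, 1, steps):
        yield (1 - t) * z_src + t * z_tgt

Each yielded z would then be decoded back to text, which is what the interpolation play_mode above produces.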
Example #13
def main(args):
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}
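    # NOTE: clip_grad, decay_epoch, lr_decay and max_decay are used below but
    # are assumed to be module-level hyper-parameters defined elsewhere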

    train_data = MonoTextData(args.train_data)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
    args.enc_nh = args.dec_nh

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args).to(device)

    if args.optim == 'sgd':
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0)
        opt_dict['lr'] = 1.0
    else:
        enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        opt_dict['lr'] = 0.001

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = -1
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    plot_data = train_data.data_sample(nsample=args.num_plot,
                                       device=device,
                                       batch_first=True)

    if args.plot_mode == 'multiple':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_multiple

    elif args.plot_mode == 'single':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_single
        posterior_mean = []
        infer_mean = []

        posterior_mean.append(
            vae.calc_model_posterior_mean(plot_data[0], grid_z))
        infer_mean.append(vae.calc_infer_mean(plot_data[0]))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            if args.plot_mode == "single":
                batch_data, _ = plot_data

            else:
                batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                if args.plot_mode == "single":
                    batch_data_enc, _ = plot_data

                else:
                    # np.random.random_integers is deprecated; randint's upper
                    # bound is exclusive, so this samples the same range
                    id_ = np.random.randint(0, len(train_data_batch))

                    batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            if args.plot_mode == 'single' and epoch == 0 and aggressive_flag:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(posterior_mean[-1])
                    infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()
            if args.plot_mode == 'single' and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(
                        vae.calc_model_posterior_mean(plot_data[0], grid_z))

                    if aggressive_flag:
                        infer_mean.append(infer_mean[-1])
                    else:
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    mi = calc_mi(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            if iter_ % args.plot_niter == 0 and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    if args.plot_mode == 'single' and iter_ != 0:
                        plot_fn(infer_mean, posterior_mean, args)
                        return
                    elif args.plot_mode == "multiple":
                        plot_fn(vae, plot_data, grid_z, iter_, args)
                vae.train()

            iter_ += 1

            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

                # return

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        with torch.no_grad():
            plot_fn(vae, plot_data, grid_z, iter_, args)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl = test(vae, val_data_batch, "VAL", args)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.optim == 'sgd':
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"])
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"])
                else:
                    enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
                    dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f' \
          % (best_loss, best_kl, best_nll, best_ppl))

    sys.stdout.flush()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
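
# For reference, the quantity calc_iwnll presumably aggregates is the
# importance-weighted estimate
#   log p(x) ~= logsumexp_k [log p(x, z_k) - log q(z_k | x)] - log K.
# A minimal sketch of that estimator; the per-sample log terms are assumed
# inputs, and their names are illustrative, not from the source:
import math

import torch


def iw_log_px(log_joint, log_posterior):
    # log_joint:     [K] tensor of log p(x, z_k) for K samples z_k ~ q(z | x)
    # log_posterior: [K] tensor of log q(z_k | x)
    log_w = log_joint - log_posterior
    # logsumexp keeps the average over importance weights numerically stable.
    return torch.logsumexp(log_w, dim=0) - math.log(log_w.size(0))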
Example #14
def model_summary(model_type,
                  img_res,
                  hidden_size,
                  enc_type,
                  dec_type,
                  loss,
                  batch_size,
                  device=torch.device("cuda:1"),
                  verbose=True):
    pattern = re.compile(r"Params size \(MB\):(.*)\n")
    pattern2 = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")
    input_dim = 3
    enc_input_size = (input_dim, img_res, img_res)
    dec_input_size = (hidden_size, img_res // 4, img_res // 4)
    if verbose:
        print(f"model:{model_type}")
        print(f"depth:{enc_type}_{dec_type}")

    if model_type == "acai":
        model = ACAI(img_res, input_dim, hidden_size, enc_type,
                     dec_type).to(device)
    elif model_type == "vqvae":
        model = VectorQuantizedVAE(input_dim,
                                   hidden_size,
                                   enc_type=enc_type,
                                   dec_type=dec_type).to(device)
    elif model_type == "vae":
        model = VAE(input_dim,
                    hidden_size,
                    enc_type=enc_type,
                    dec_type=dec_type).to(device)

    encoder_summary, _ = torchsummary.summary_string(model.encoder,
                                                     enc_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    decoder_summary, _ = torchsummary.summary_string(model.decoder,
                                                     dec_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    if verbose:
        print(encoder_summary)
        print(decoder_summary)

    discriminators = {}

    if model_type == "acai":
        disc = Discriminator(input_dim, img_res, "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["interp_disc"] = (disc_param_size, disc_forward_size)
    if loss == "gan":
        disc = Discriminator(input_dim, img_res, "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif loss == "comp":
        disc = AnchorComparator(input_dim * 2, img_res, "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_2" in loss:
        disc = ClubbedPermutationComparator(input_dim * 2, img_res,
                                            "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_6" in loss:
        disc = FullPermutationComparator(input_dim * 2, img_res,
                                         "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)

    encoder_param_size = float(re.search(pattern, encoder_summary).group(1))
    encoder_forward_size = float(re.search(pattern2, encoder_summary).group(1))
    decoder_param_size = float(re.search(pattern, decoder_summary).group(1))
    decoder_forward_size = float(re.search(pattern2, decoder_summary).group(1))

    if verbose:
        if "ACAI" in str(type(model)):
            print(
                f"discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        if loss == "gan":
            print(
                f"reconstruction discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        print(
            f"encoder:\n\tparams:{encoder_param_size}\n\tforward:{encoder_forward_size}"
        )
        print(
            f"decoder:\n\tparams:{decoder_param_size}\n\tforward:{decoder_forward_size}"
        )

    encoder = {"params": encoder_param_size, "forward": encoder_forward_size}
    decoder = {"params": decoder_param_size, "forward": decoder_forward_size}

    return encoder, decoder, discriminators
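
# A usage sketch for model_summary; the argument values below are illustrative
# assumptions, not taken from the source:
import torch

if __name__ == "__main__":
    enc, dec, discs = model_summary(model_type="vae",
                                    img_res=64,        # hypothetical resolution
                                    hidden_size=128,   # hypothetical width
                                    enc_type="2",      # hypothetical depth tag
                                    dec_type="2",      # hypothetical depth tag
                                    loss="gan",
                                    batch_size=32,
                                    device=torch.device("cuda:0"))
    print(f"encoder: {enc}")
    print(f"decoder: {dec}")
    print(f"discriminators: {discs}")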
Example #15
def main(args):
    global logging
    debug = (args.reconstruct_from != ""
             or args.eval == True)  # don't make exp dir for reconstruction
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)

            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))

        return

    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()

        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to)

        return

    if args.opt == "sgd":
        enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):

                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

                kl_weight = args.beta

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data,
                                                      kl_weight,
                                                      nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data,
                        kl_weight,
                        nsamples=args.iw_train_nsamples,
                        ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                           'time elapsed %.2fs, kl_weight %.4f' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start, kl_weight))

                    #sys.stdout.flush()

                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)
                # print(au_var)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(
                    vae.state_dict(),
                    os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if (opt_dict["not_improved"] >= decay_epoch
                        and epoch >= args.load_best_epoch):
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)

            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                                 args)

            if args.save_latent > 0 and epoch <= args.save_latent:
                visualize_latent(args, epoch, vae, "cuda", test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
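
# The hand-rolled plateau decay above (opt_dict / decay_epoch / lr_decay /
# max_decay) can mostly be expressed with PyTorch's built-in scheduler. A
# minimal sketch, assuming lr_decay is the multiplicative factor and
# decay_epoch the patience; unlike the manual loop, this variant does not
# reload the best checkpoint whenever the learning rate drops:
import torch.optim as optim


def make_plateau_schedulers(enc_optimizer, dec_optimizer, lr_decay, decay_epoch):
    # Call scheduler.step(val_loss) once per validation pass.
    enc_sched = optim.lr_scheduler.ReduceLROnPlateau(
        enc_optimizer, mode="min", factor=lr_decay, patience=decay_epoch)
    dec_sched = optim.lr_scheduler.ReduceLROnPlateau(
        dec_optimizer, mode="min", factor=lr_decay, patience=decay_epoch)
    return enc_sched, dec_sched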
Example #16
def main(args):

	train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
	input_dim = 3
	model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
	opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


	discriminators = {}

	if args.recons_loss != "mse":
		if args.recons_loss == "gan":
			recons_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
		elif args.recons_loss == "comp":
			recons_disc = AnchorComparator(input_dim*2, args.img_res, args.input_type).to(args.device)
		elif "comp_2" in args.recons_loss:
			recons_disc = ClubbedPermutationComparator(input_dim*2, args.img_res, args.input_type).to(args.device)
		elif "comp_6" in args.recons_loss:
			recons_disc = FullPermutationComparator(input_dim*2, args.img_res, args.input_type).to(args.device)

		recons_disc_opt = torch.optim.Adam(recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)
		
		discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

	if torch.cuda.device_count() > 1:
		model = train_util.ae_data_parallel(model)
		for disc in discriminators:
			discriminators[disc][0] = torch.nn.DataParallel(discriminators[disc][0])

	model.to(args.device)

	model_name = f"vae_{args.recons_loss}"
	if args.output_folder is None:
		args.output_folder = os.path.join(model_name, args.dataset, f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}")

	log_save_path = os.path.join("./logs", args.output_folder)
	model_save_path = os.path.join("./models", args.output_folder)

	if not os.path.exists(log_save_path):
		os.makedirs(log_save_path)
		print(f"log:{log_save_path}", file=sys.stderr)
		sys.stderr.flush()
	if not os.path.exists(model_save_path):
		os.makedirs(model_save_path)


	writer = SummaryWriter(log_save_path)

	print(f"train loader length:{len(train_loader)}", file=sys.stderr)
	best_loss = torch.tensor(np.inf)
	
	if args.weights == "load":
		start_epoch = train_util.load_state(model_save_path, model, opt, discriminators)
	else:
		start_epoch = 0

	recons_input_img = train_util.log_input_img_grid(test_loader, writer)

	train_util.save_recons_img_grid("val", recons_input_img, model, 0, args)


	for epoch in range(start_epoch, args.num_epochs):
		print("Epoch {}:".format(epoch))
		train(model, opt, train_loader)
		curr_loss = val(model, val_loader)
		# val_loss_dict, z = train_util.test(get_losses, model, val_loader, args, discriminators)

		print(f"epoch val loss:{curr_loss}", file=sys.stderr)
		sys.stderr.flush()
		train_util.save_recons_img_grid("val", recons_input_img, model, epoch+1, args)
		train_util.save_interp_img_grid("val", recons_input_img, model, epoch+1, args)
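
# train_util.ae_data_parallel is project-specific; a plausible reason it
# exists is that nn.DataParallel only replicates forward(), so custom
# autoencoder methods such as decode() would otherwise bypass the wrapper.
# A minimal, assumed sketch of such a wrapper:
import torch.nn as nn


class AEDataParallel(nn.DataParallel):
    # Expose the wrapped autoencoder's helpers on the wrapper itself. Note
    # these delegated calls run on the primary device, not in parallel.
    def encode(self, x):
        return self.module.encode(x)

    def decode(self, z):
        return self.module.decode(z)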
Example #17
0
train_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)(
    '../data/{}/'.format(DATASET),
    train=True,
    download=True,
    transform=preproc_transform,
),
                                           batch_size=BATCH_SIZE,
                                           shuffle=False,
                                           num_workers=NUM_WORKERS,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)(
    '../data/{}/'.format(DATASET), train=False, transform=preproc_transform),
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=NUM_WORKERS,
                                          pin_memory=True)

model = VAE(INPUT_DIM, DIM, Z_DIM).cuda()
print(model)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    train_loss = []
    model.train()
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.cuda()

        x_tilde, kl_d = model(x)
        loss_recons = F.mse_loss(x_tilde, x, reduction='sum') / x.size(0)
        loss = loss_recons + kl_d
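
# kl_d above presumably comes from the closed-form KL divergence between the
# diagonal-Gaussian posterior N(mu, diag(exp(logvar))) and the standard-normal
# prior. A minimal sketch of that term (names are illustrative):
import torch


def gaussian_kl(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2),
    # summed over latent dimensions and averaged over the batch to match the
    # per-sample reconstruction loss above.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kl.mean()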
Example #18
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--checkpoint_dir",
                        default=None,
                        type=str,
                        help="The directory where checkpoints are saved.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--dataset",
                        default=None,
                        type=str,
                        help="The dataset.")

    ## Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help=
        "An optional input evaluation data file to evaluate the perplexity on (a text file)."
    )
    parser.add_argument("--ExpName",
                        default="",
                        type=str,
                        help="The experiment name used in Azure Table.")
    parser.add_argument("--save_bert_gpt_init",
                        action='store_true',
                        help="Use Philly for computing.")
    parser.add_argument(
        "--length_weighted_loss",
        action='store_true',
        help="Use sentence length re-weight the reconstruction loss.")

    ## Decoder options
    parser.add_argument(
        "--decoder_model_type",
        default="gpt2",
        type=str,
        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--decoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--decoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--decoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    ## Variational auto-encoder
    parser.add_argument("--latent_size",
                        default=32,
                        type=int,
                        help="Latent space dimension.")
    parser.add_argument(
        "--use_deterministic_connect",
        action='store_true',
        help=
        "Use deterministic inference to generate latent codes, i.e., standard auto-encoders."
    )
    parser.add_argument(
        "--use_pretrained_model",
        action='store_true',
        help="Use pre-trained auto-encoder models as the initialization")
    parser.add_argument("--latent_as_gpt_memory",
                        default=1,
                        type=int,
                        help="Latent vector as memery for GPT2 to attend.")
    parser.add_argument("--latent_as_gpt_emb",
                        default=1,
                        type=int,
                        help="Latent vector as embeddings for GPT2.")

    ## Objective functions
    parser.add_argument(
        "--mlm",
        action='store_true',
        help=
        "Train with masked-language modeling loss instead of language modeling."
    )
    parser.add_argument(
        "--mlm_probability",
        type=float,
        default=0.15,
        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument(
        "--beta",
        type=float,
        default=1.0,
        help="The weighting hyper-parameter of the KL term in VAE")

    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)"
    )
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length"
    )
    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_eval_rec",
        action='store_true',
        help="Whether to run eval reconstruction on a set of models.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Run evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    # Training Schedules
    parser.add_argument(
        "--ratio_increase",
        default=0.25,
        type=float,
        help="Learning schedule, the percentage for the annealing stage.")
    parser.add_argument(
        "--ratio_zero",
        default=0.25,
        type=float,
        help=
        "Learning schedule, the percentage for the pure auto-encoding stage.")
    parser.add_argument("--fb_mode",
                        default=0,
                        type=int,
                        help="free bit training mode.")
    parser.add_argument("--dim_target_kl",
                        default=3.0,
                        type=float,
                        help="dim_target_kl free bit training mode.")
    parser.add_argument("--per_gpu_train_batch_size",
                        default=4,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=1,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=1.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly",
                        action='store_true',
                        help="Use Philly for computing.")
    parser.add_argument(
        "--use_pretrained_vae",
        action='store_true',
        help=
        "Use use_pretrained_vae as initialization, where beta value is specified in the folder"
    )
    parser.add_argument("--use_random_weight",
                        action='store_true',
                        help="Use random weights as initialization")

    ## IO: Logging and Saving
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--global_step_eval',
                        type=int,
                        default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training
    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError(
            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # # Setup distant debugging if needed
    # if args.server_ip and args.server_port:
    #     # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
    #     import ptvsd
    #     print("Waiting for debugger attach")
    #     ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
    #     ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(
        args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(
            args.dim_target_kl) + '_Ra_' + str(
                args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)

    try:
        ts.create_table(table_name)
    except Exception:
        # Best-effort: the table may already exist.
        pass

    # Set seed
    set_seed(args)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    # Load Optimius pre-trained model and tokenizer
    if args.use_pretrained_model:
        args.decoder_model_type = args.decoder_model_type.lower()

        global_step = args.global_step_eval

        output_decoder_dir = os.path.join(
            args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir = os.path.join(
            args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))

        checkpoints = [[output_decoder_dir]]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        # Load a trained Encoder model and vocabulary

        # Load a trained Decoder model and vocabulary
        decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
            args.decoder_model_type]
        model_decoder = decoder_model_class.from_pretrained(
            output_decoder_dir, latent_size=args.latent_size)
        tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
            args.decoder_tokenizer_name if args.decoder_tokenizer_name else
            args.decoder_model_name_or_path,
            do_lower_case=args.do_lower_case)
        model_decoder.to(args.device)

        if args.block_size <= 0:
            args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
            # print("block size ", args.block_size)
        args.block_size = min(args.block_size,
                              tokenizer_decoder.max_len_single_sentence)
        print("block size: ", args.block_size)

        # Load full model
        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))

    else:
        # Load BERT and GPT weights (as an alternative, one may train a VAE from scratch for this small-scale setting)

        ## Decoder
        decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
            args.decoder_model_type]
        decoder_config = decoder_config_class.from_pretrained(
            args.decoder_config_name if args.decoder_config_name else args.
            decoder_model_name_or_path)
        tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
            args.decoder_tokenizer_name if args.decoder_tokenizer_name else
            args.decoder_model_name_or_path,
            do_lower_case=args.do_lower_case)
        if args.block_size <= 0:
            args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size,
                              tokenizer_decoder.max_len_single_sentence)

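        # Two injection schemes for the latent vector (per the flag help
        # strings above): as an extra embedding added for GPT-2
        # (latent_as_gpt_emb) and/or as a "memory" vector that GPT-2 attends
        # to (latent_as_gpt_memory). At least one must be enabled for the
        # latent to reach the decoder.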
        if args.latent_as_gpt_emb + args.latent_as_gpt_memory == 0:
            return  # latent vector should pass into GPT to decode
        else:
            latent_as_gpt_emb = (args.latent_as_gpt_emb == 1)
            latent_as_gpt_memory = (args.latent_as_gpt_memory == 1)

        setattr(decoder_config, "latent_size", args.latent_size)
        model_decoder = decoder_model_class.from_pretrained(
            args.decoder_model_name_or_path,
            from_tf=bool('.ckpt' in args.decoder_model_name_or_path),
            config=decoder_config,
            latent_size=args.latent_size,
            latent_as_gpt_emb=latent_as_gpt_emb,
            latent_as_gpt_memory=latent_as_gpt_memory)

    # Save the initial weights of BERT and GPT-2, so that we can load them locally (some infrastructure requires this)
    if args.save_bert_gpt_init:

        decoder_path = os.path.join(
            args.output_dir,
            f"initial-models-tokenization-decoder-{args.latent_size}")
        if not os.path.exists(decoder_path): os.makedirs(decoder_path)
        model_decoder.save_pretrained(decoder_path)
        tokenizer_decoder.save_pretrained(decoder_path)

        return

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(
        len(tokenizer_decoder)
    )  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    # model_decoder.to(args.device)

    model_vae = AE(model_decoder, tokenizer_decoder, args)

    # pdb.set_trace()
    if args.use_random_weight:
        print("---" * 20)
        print("random weight init")
        print("---" * 20)
        model_vae.apply(weights_init_rondom)

    if args.use_pretrained_model:
        pre_model = checkpoint['model_state_dict']
        model_dict = model_vae.state_dict()

        pre_dict = {k: v for k, v in pre_model.items() if k in model_dict}
        model_dict.update(pre_dict)
        model_vae.load_state_dict(model_dict)

        # model_vae.load_state_dict(checkpoint['model_state_dict'])
        logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)
    # on_gpu = next(model_vae.parameters()).is_cuda

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    ##############################
    # Training
    global_step = 0
    if args.do_train:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier(
            )  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        print("===" * 20)
        train_dataloader = build_dataload_and_cache_examples(
            args, [tokenizer_decoder], evaluate=False)
        print("===" * 20)

        if args.local_rank == 0:
            torch.distributed.barrier()

        # print("+++"*20, "training", "+++"*20)
        global_step, tr_loss, optimizer = train(args, train_dataloader,
                                                model_vae, tokenizer_decoder,
                                                table_name)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        save_checkpoint(model_vae, optimizer, global_step, args)

    ##############################
    # Evaluation the metrics of VAE models, including PPL, MI, AU
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        if global_step == 0:
            global_step = args.global_step_eval

        output_decoder_dir = os.path.join(
            args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir = os.path.join(
            args.output_dir, 'checkpoint-full-{}'.format(global_step))
        checkpoint_dir = [output_decoder_dir, output_full_dir]

        logger.info("Evaluate the following checkpoint: %s",
                    checkpoint_dir[-1])
        global_step = checkpoint_dir[-1].split(
            '-')[-1] if len(checkpoint_dir) > 1 else ""

        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
        model_vae.load_state_dict(checkpoint['model_state_dict'])
        logger.info(
            f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
        model_vae.to(args.device)

        result = evaluate(args,
                          model_vae,
                          tokenizer_decoder,
                          table_name,
                          prefix=global_step,
                          subset='test')
        result = dict(
            (k + '_{}'.format(global_step), v) for k, v in result.items())
        results.update(result)

        output_eval_file = os.path.join(args.output_dir,
                                        "eval_vae_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("%s = %s", key, str(results[key]))
                writer.write("%s = %s\n" % (key, str(results[key])))
        logger.info(
            f"The testing results are successfully saved: {output_eval_file}")

    ##############################
    #  Evaluate the reconstruction loss for each checkpoints;
    # This is used in studying two different latent vector injection schemes
    results = {}
    if args.do_eval_rec and args.local_rank in [-1, 0]:
        if global_step == 0:
            global_step = args.global_step_eval
            # eval_steps = range(500, 13500, 500)
            # eval_steps = range(1000, 2000, 500)
            eval_steps = range(2000, 32000, 2000)

        checkpoints = []
        for e in eval_steps:
            output_decoder_dir = os.path.join(
                args.output_dir, 'checkpoint-decoder-{}'.format(e))
            checkpoints.append([output_decoder_dir])

        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint[0].split(
                '-')[-1] if len(checkpoints) > 1 else ""

            model_decoder = decoder_model_class.from_pretrained(checkpoint[0])
            model_decoder.to(args.device)

            model_vae = VAE(model_decoder, tokenizer_decoder,
                            args).to(args.device)

            result = evaluate_rec(args,
                                  model_vae,
                                  tokenizer_decoder,
                                  table_name,
                                  prefix=global_step,
                                  subset='test')
            result = dict((k + '_test_{}'.format(global_step), v)
                          for k, v in result.items())
            results.update(result)

            result = evaluate_rec(args,
                                  model_vae,
                                  tokenizer_decoder,
                                  table_name,
                                  prefix=global_step,
                                  subset='train')
            result = dict((k + '_train_{}'.format(global_step), v)
                          for k, v in result.items())
            results.update(result)

            # pdb.set_trace()

        output_eval_file = os.path.join(args.output_dir,
                                        "eval_rec_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("%s = %s", key, str(results[key]))
                writer.write("%s = %s\n" % (key, str(results[key])))
        logger.info(
            f"The testing results are successfully saved: {output_eval_file}")

    return results
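
# The --fb_mode / --dim_target_kl flags above suggest free-bits style
# training: each latent dimension's KL contribution is floored at a target
# value so the optimizer has no incentive to collapse that dimension to zero.
# A minimal sketch of the thresholding (names are illustrative, not the
# source's code):
import torch


def free_bits_kl(kl_per_dim, dim_target_kl):
    # kl_per_dim: [batch, latent_size] per-dimension KL terms.
    return torch.clamp(kl_per_dim, min=dim_target_kl).sum(dim=1).mean()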
Example #19
def main(args):

    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, "min", patience=args.lr_patience, factor=0.5,
    # 	threshold=args.threshold, threshold_mode="abs", min_lr=1e-6)

    # ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, "min", patience=args.lr_patience, factor=0.5,
    # 	threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

    discriminators = {}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    print("model built", file=sys.stderr)
    #print("model created")
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    print("loaders acquired", file=sys.stderr)
    #print("loaders acquired")

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}"
        )

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        os.makedirs(log_save_path)
        print(f"log:{log_save_path}", file=sys.stderr)
        sys.stderr.flush()
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    writer = SummaryWriter(log_save_path)

    print(f"train loader length:{len(train_loader)}", file=sys.stderr)
    best_loss = torch.tensor(np.inf)

    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    stop_patience = args.stop_patience
    for epoch in range(start_epoch, args.num_epochs):

        try:
            train(model, train_loader, opt, epoch, writer, args,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args,
                                           discriminators)
        print(f"epoch loss:{val_loss_dict['recons_loss'].item()}")

        train_util.save_recons_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, opt, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, model_save_path)
Example #20
def main(args):
    writer = SummaryWriter('./logs_vae/{0}'.format(args.output_folder))
    save_filename = './models_vae/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VAE(num_channels, args.hidden_size, args.z).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(epoch, train_loader, model, optimizer, args, writer)
        loss = test(valid_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
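Both the per-epoch checkpoints and best.pt above are plain state dicts, so reloading one for sampling only needs the matching constructor. A minimal reload sketch, assuming the same VAE(num_channels, hidden_size, z) signature and args namespace used above:

import torch

def load_best_checkpoint(save_filename, num_channels, args):
    # Rebuild the model exactly as above, then restore the best weights;
    # map_location lets a GPU-trained checkpoint load on a CPU-only host.
    model = VAE(num_channels, args.hidden_size, args.z).to(args.device)
    with open('{0}/best.pt'.format(save_filename), 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=args.device))
    model.eval()
    return model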
Example #21
def main(args, args_model):
    global logging
    logging = get_logger_existing_dir(os.path.dirname(args.load_path),
                                      'log_classifier.txt')

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    if getattr(args, 'vocab_file', None) is not None:
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i

        vocab = VocabEntry(vocab)

    filename_glob = args.train_data + '.seed_*.n_' + str(
        args.num_label_per_class)
    train_sets = glob.glob(filename_glob)
    print("Train sets:", train_sets)

    main_train_data = MonoTextData(args.train_data,
                                   label=args.label,
                                   vocab=vocab)
    vocab = main_train_data.vocab
    vocab_size = len(vocab)

    logging('finish reading datasets, vocab size is %d' % len(vocab))
    #sys.stdout.flush()

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args_model.device = device

    if args_model.enc_type == 'lstm':
        args_model.pooling = getattr(args_model, 'pooling', None)
        encoder = GaussianLSTMEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            pooling=args_model.pooling,
        )

    elif args_model.enc_type in ['max_avg_pool', 'max_pool', 'avg_pool']:
        args_model.skip_first_word = getattr(args_model, 'skip_first_word',
                                             None)
        encoder = GaussianPoolEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            enc_type=args_model.enc_type,
            skip_first_word=args_model.skip_first_word)
        #args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    args_model.encode_length = getattr(args_model, 'encode_length', None)
    if args_model.dec_type == 'lstm':
        decoder = LSTMDecoder(args_model, vocab, model_init, emb_init,
                              args_model.encode_length)
    elif args_model.dec_type == 'unigram':
        decoder = UnigramDecoder(args_model, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args_model,
              args_model.encode_length).to(device)

    if args.load_path:
        print("load args!")
        print(vae)
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    vae.eval()

    def preprocess(data_fn):
        codes, labels = read_dataset(data_fn, vocab, device, vae,
                                     args.classify_using_samples)
        if args.classify_using_samples:
            is_gaussian_enc = codes.shape[1] == (vae.encoder.nz * 2)
            codes = augment_dataset(codes, 1, is_gaussian_enc,
                                    vae)  # use only 1 sample for test
        codes = codes.cpu().numpy()
        labels = labels.cpu().numpy()
        return codes, labels

    test_codes, test_labels = preprocess(args.test_data)

    test_f1_scores = []
    average_f1 = 'macro'
    f1_scorer = make_scorer(f1_score,
                            average=average_f1,
                            labels=np.unique(test_labels),
                            greater_is_better=True)
    # log loss: negative log likelihood. We should minimize that, so greater_is_better=False
    log_loss_scorer = make_scorer(log_loss,
                                  needs_proba=True,
                                  greater_is_better=False)
    warnings.filterwarnings('ignore')
    results = {
        'n_samples_per_class': args.num_label_per_class,
    }
    n_repeats = args.n_repeats

    n_splits = min(args.num_label_per_class, 5)
    for i, fn in enumerate(train_sets):
        codes, labels = preprocess(fn)
        if args.resample > 1:
            # going to augment the training set by sampling
            # then create a new cross validation function to get the correct indices
            cross_val = augment_cross_val(labels, args.resample, n_splits,
                                          n_repeats)
            labels = np.repeat(labels, args.resample)
        else:
            cross_val = RepeatedStratifiedKFold(n_splits=n_splits,
                                                n_repeats=n_repeats)

        scaler = StandardScaler()
        codes = scaler.fit_transform(codes)
        scaled_test_codes = scaler.transform(test_codes)
        gridsearch = GridSearchCV(
            LogisticRegression(solver='sag', multi_class='auto'),
            {
                "penalty": ['l2'],
                "C": [0.01, 0.1, 1, 10, 100],
            },
            cv=cross_val,
            scoring={
                "f1": f1_scorer,
                "log": log_loss_scorer,
            },
            refit=False,
        )
        clf = gridsearch
        clf.fit(codes, labels)
        crossval_f1, test_f1 = refit_and_eval(
            'f1',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            f1_scorer,
        )
        crossval_log, test_log_loss = refit_and_eval(
            'log',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            log_loss_scorer,
        )
        results[i] = {
            "F1": {
                'crossval': crossval_f1,
                'test': test_f1
            },
            "log": {
                'crossval': crossval_log,
                'test': test_log_loss
            },
        }
        print(results[i])

    n_per_class = str(args.num_label_per_class)  # needed on both branches below
    if args.classify_using_samples:
        resample = 1 if args.resample == -1 else args.resample
        output_fn = os.path.join(
            args.exp_dir,
            'results_sample_' + str(resample) + '_' + n_per_class + '.json')
    else:
        output_fn = os.path.join(args.exp_dir,
                                 'results_' + n_per_class + '.json')
    with open(output_fn, 'w') as f:
        json.dump(results, f)
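refit_and_eval is not shown in this example. A minimal sketch of what it plausibly does with a refit=False GridSearchCV (the function name and body here are hypothetical): pick the parameter setting with the best mean cross-validated score for one metric, refit a fresh classifier on the full training codes, and score it on the held-out test codes.

import numpy as np
from sklearn.linear_model import LogisticRegression

def refit_and_eval_sketch(metric, cv_results, codes, labels,
                          test_codes, test_labels, scorer):
    # make_scorer sign-adjusts losses (greater_is_better=False negates),
    # so argmax over mean_test_<metric> is correct for both scorers.
    scores = cv_results['mean_test_%s' % metric]
    best = int(np.argmax(scores))
    clf = LogisticRegression(solver='sag', multi_class='auto',
                             **cv_results['params'][best])
    clf.fit(codes, labels)
    return scores[best], scorer(clf, test_codes, test_labels)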
Example #22
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--input_file_path",
        default=None,
        type=str,
        required=True,
        help="The output directory where the input files will be written.")
    parser.add_argument(
        "--output_file_path",
        default=None,
        type=str,
        required=True,
        help="The output directory where the output files will be written.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the logs and results will be saved.")
    parser.add_argument("--dataset",
                        default=None,
                        type=str,
                        help="The dataset.")

    ## Other parameters
    parser.add_argument("--ExpName",
                        default="",
                        type=str,
                        help="The experiment name used in Azure Table.")

    ## Encoder options
    parser.add_argument(
        "--encoder_model_type",
        default="bert",
        type=str,
        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--encoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--encoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--encoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    ## Decoder options
    parser.add_argument(
        "--decoder_model_type",
        default="gpt2",
        type=str,
        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument(
        "--decoder_model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument(
        "--decoder_config_name",
        default="",
        type=str,
        help=
        "Optional pretrained config name or path if not the same as model_name_or_path"
    )
    parser.add_argument(
        "--decoder_tokenizer_name",
        default="",
        type=str,
        help=
        "Optional pretrained tokenizer name or path if not the same as model_name_or_path"
    )

    ## Variational auto-encoder
    parser.add_argument("--latent_size",
                        default=32,
                        type=int,
                        help="Latent space dimension.")
    parser.add_argument(
        "--use_deterministic_connect",
        action='store_true',
        help=
        "Use deterministic inference to generate latent codes, i.e., standard auto-encoders."
    )

    ## Objective functions
    parser.add_argument(
        "--mlm",
        action='store_true',
        help=
        "Train with masked-language modeling loss instead of language modeling."
    )
    parser.add_argument(
        "--mlm_probability",
        type=float,
        default=0.15,
        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument(
        "--beta",
        type=float,
        default=1.0,
        help="The weighting hyper-parameter of the KL term in VAE")

    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)"
    )
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length"
    )
    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Run evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    # Training Schedules
    parser.add_argument(
        "--ratio_increase",
        default=0.25,
        type=float,
        help="Learning schedule, the percentage for the annealing stage.")
    parser.add_argument(
        "--ratio_zero",
        default=0.25,
        type=float,
        help=
        "Learning schedule, the percentage for the pure auto-encoding stage.")
    parser.add_argument("--fb_mode",
                        default=0,
                        type=int,
                        help="free bit training mode.")
    parser.add_argument("--dim_target_kl",
                        default=3.0,
                        type=float,
                        help="dim_target_kl free bit training mode.")
    parser.add_argument("--per_gpu_train_batch_size",
                        default=4,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=1,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=1.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly",
                        action='store_true',
                        help="Use Philly for computing.")

    ## IO: Logging and Saving
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--global_step_eval',
                        type=int,
                        default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training
    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError(
            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling).")

    if os.path.exists(args.output_file_path) and os.listdir(
            args.output_file_path
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_file_path))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    logger.info(f'Local rank is {args.local_rank}')
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(
        args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(
            args.dim_target_kl) + '_Ra_' + str(
                args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
    try:
        ts.create_table(table_name)
    except Exception:
        pass  # the Azure table may already exist

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    ## Encoder
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[
        args.encoder_model_type]
    encoder_config = encoder_config_class.from_pretrained(
        args.encoder_config_name if args.encoder_config_name else args.
        encoder_model_name_or_path)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name
        if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_encoder.max_len_single_sentence)
    model_encoder = encoder_model_class.from_pretrained(
        args.encoder_model_name_or_path,
        from_tf=bool('.ckpt' in args.encoder_model_name_or_path),
        config=encoder_config,
        latent_size=args.latent_size)
    # model_encoder.to(args.device)

    ## Decoder
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
        args.decoder_model_type]
    decoder_config = decoder_config_class.from_pretrained(
        args.decoder_config_name if args.decoder_config_name else args.
        decoder_model_name_or_path)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name
        if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_decoder.max_len_single_sentence)
    model_decoder = decoder_model_class.from_pretrained(
        args.decoder_model_name_or_path,
        from_tf=bool('.ckpt' in args.decoder_model_name_or_path),
        config=decoder_config,
        latent_size=args.latent_size)

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(
        len(tokenizer_decoder)
    )  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    # model_decoder.to(args.device)

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args).to(args.device)  #

    # on_gpu = next(model_vae.parameters()).is_cuda

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    global_step = 0
    # Training
    if args.do_train:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier(
            )  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataloader = build_dataload_and_cache_examples(
            args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)

        if args.local_rank == 0:
            torch.distributed.barrier()

        num_collected, num_dropped = train(args, train_dataloader, model_vae,
                                           tokenizer_encoder,
                                           tokenizer_decoder, table_name)
        logger.info(" num_collected = %s, num_dropped = %s", num_collected,
                    num_dropped)
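The schedule flags above (--ratio_zero, --ratio_increase, --beta, --dim_target_kl, --fb_mode) suggest a pure auto-encoding stage followed by linear KL annealing, plus a free-bits floor on the per-dimension KL. Neither helper is visible in this example; the following is a hedged sketch of common implementations of both ideas, not this script's actual code:

import torch

def beta_schedule(step, total_steps, beta=1.0, ratio_zero=0.25, ratio_increase=0.25):
    # beta = 0 for the first ratio_zero fraction of training (pure AE),
    # then rises linearly to its target over the next ratio_increase fraction.
    t = step / float(total_steps)
    if t < ratio_zero:
        return 0.0
    if t < ratio_zero + ratio_increase:
        return beta * (t - ratio_zero) / ratio_increase
    return beta

def free_bits_kl(kl_per_dim, dim_target_kl=3.0):
    # One common free-bits variant: dimensions whose KL already exceeds the
    # target keep their gradient; the rest contribute a constant floor.
    return torch.clamp(kl_per_dim, min=dim_target_kl).sum(dim=-1)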
Example #23
def main(args):
    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    
    vocab_path = os.path.join("/".join(args.train_data.split("/")[:-1]), "vocab.txt")
    with open(vocab_path, "w") as fout:
        for i in range(vocab_size):
            fout.write("{}\n".format(vocab.id2word(i)))
        #return

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data)//args.batch_size)//10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    print('begin evaluation')
    vae.load_state_dict(torch.load(args.load_path))
    vae.eval()
    with torch.no_grad():
        test_data_batch, test_batch_labels = test_data.create_data_batch_labels(batch_size=args.batch_size,
                                                      device=device,
                                                      batch_first=True)

        # test(vae, test_data_batch, "TEST", args)
        # au, au_var = calc_au(vae, test_data_batch)
        # print("%d active units" % au)

        train_data_batch, train_batch_labels = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                        device=device,
                                                        batch_first=True)

        val_data_batch, val_batch_labels = val_data.create_data_batch_labels(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

        print("getting  vectors for training")
        print(args.save_dir)
        save_latents(args, vae, train_data_batch, train_batch_labels, "train")
        print("getting  vectors for validating")
        save_latents(args, vae, val_data_batch, val_batch_labels, "val")
        print("getting  vectors for testing")
        save_latents(args, vae, test_data_batch, test_batch_labels, "test")
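save_latents itself is not shown. A hypothetical version consistent with the calls above would encode each batch to its posterior statistics and write one array per split; the encoder call signature here is an assumption:

import os
import numpy as np

def save_latents_sketch(args, vae, data_batch, labels_batch, split):
    mus, labels = [], []
    for batch, batch_labels in zip(data_batch, labels_batch):
        mu, _ = vae.encoder(batch)  # assumed to return (mu, logvar)
        mus.append(mu.cpu().numpy())
        labels.extend(int(x) for x in batch_labels)
    np.save(os.path.join(args.save_dir, '%s_latents.npy' % split),
            np.concatenate(mus, axis=0))
    np.save(os.path.join(args.save_dir, '%s_labels.npy' % split),
            np.array(labels))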
Example #24
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=1,
                                          pin_memory=True)
print('train_loader', len(train_loader))
print('test_loader', len(test_loader))

#%%
model = VAE().to(device)
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer = optim.Adam(model.parameters(),
                       lr=1e-3,
                       betas=(0.9, 0.999),
                       weight_decay=0.0005)

#%%
epochs = 100
viz = Visdom()
global plotter, recon
plotter = utils.VisdomLinePlotter(env_name='main')
sample_image = utils.VisdomImage(env_name='main')
recon = utils.VisdomImage(env_name='main')

for epoch in range(1, epochs + 1):
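The loop body is truncated in the original. A plausible epoch step, assuming train()/test() helpers that return scalar losses and the widely copied VisdomLinePlotter.plot(var_name, split_name, title, x, y) signature, might look like:

for epoch in range(1, epochs + 1):
    train_loss = train(model, device, train_loader, optimizer, epoch)  # assumed helper
    test_loss = test(model, device, test_loader)                       # assumed helper
    plotter.plot('loss', 'train', 'VAE loss', epoch, train_loss)
    plotter.plot('loss', 'test', 'VAE loss', epoch, test_loss)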
Example #25
def main(args):
    if args.save_path == '':
        make_savepath(args)
        seed(args)

    if args.cuda:
        print('using cuda')

    print(args)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data

    x_train = x_train.to(device)
    x_val = x_val.to(device)
    x_test = x_test.to(device)
    y_size = 1
    y_train = x_train.new_zeros(x_train.size(0), y_size)
    y_val = x_train.new_zeros(x_val.size(0), y_size)
    y_test = x_train.new_zeros(x_test.size(0), y_size)
    print(torch.__version__)
    train_data = torch.utils.data.TensorDataset(x_train, y_train)
    val_data = torch.utils.data.TensorDataset(x_val, y_val)
    test_data = torch.utils.data.TensorDataset(x_test, y_test)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    sys.stdout.flush()

    log_niter = len(train_loader) // 5

    encoder = ResNetEncoderV2(args)
    decoder = PixelCNNDecoderV2(args)

    vae = VAE(encoder, decoder, args).to(device)

    if args.sample_from != '':
        save_dir = "samples/%s" % args.dataset
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        vae.load_state_dict(torch.load(args.sample_from))
        vae.eval()
        with torch.no_grad():
            sample_z = vae.sample_from_prior(400).to(device)
            sample_x, sample_probs = vae.decode(sample_z, False)
        image_file = 'sample_binary_from_%s.png' % (
            args.sample_from.split('/')[-1][:-3])
        save_image(sample_x.data.cpu(),
                   os.path.join(save_dir, image_file),
                   nrow=20)
        image_file = 'sample_cont_from_%s.png' % (
            args.sample_from.split('/')[-1][:-3])
        save_image(sample_probs.data.cpu(),
                   os.path.join(save_dir, image_file),
                   nrow=20)

        return

    if args.eval:
        print('begin evaluation')
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  batch_size=50,
                                                  shuffle=True)
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test(vae, test_loader, "TEST", args)
            au, au_var = calc_au(vae, test_loader)
            print("%d active units" % au)
            # print(au_var)

            calc_iwnll(vae, test_loader, args)

        return

    enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
    dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
    opt_dict['lr'] = 0.001

    iter_ = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    decay_cnt = pre_mi = best_mi = mi_not_improved = 0
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader))

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_examples = 0
        for datum in train_loader:
            batch_data, _ = datum
            batch_data = torch.bernoulli(batch_data)
            batch_size = batch_data.size(0)

            report_num_examples += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_examples = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_num_examples += batch_data_enc.size(0)
                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                id_ = np.random.choice(x_train.size(0),
                                       args.batch_size,
                                       replace=False)

                batch_data_enc = torch.bernoulli(x_train[id_])

                if sub_iter % 10 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_examples
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_examples = 0

                sub_iter += 1

            # print(sub_iter)

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_examples
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_loader)
                        au, _ = calc_au(vae, val_loader)

                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, ' \
                           'au %d, time elapsed %.2fs' %
                           (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi,
                           report_rec_loss / report_num_examples, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, ' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples,
                          report_rec_loss / report_num_examples, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_examples = 0

            iter_ += 1

            if aggressive_flag and (iter_ % len(train_loader)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_loader)
                vae.train()
                if cur_mi - best_mi < 0:
                    mi_not_improved += 1
                    if mi_not_improved == 5:
                        aggressive_flag = False
                        print("STOP BURNING")

                else:
                    best_mi = cur_mi

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        vae.eval()

        with torch.no_grad():
            loss, nll, kl = test(vae, val_loader, "VAL", args)
            au, au_var = calc_au(vae, val_loader)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            torch.save(vae.state_dict(), args.save_path)

        if loss > best_loss:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                decay_cnt += 1
                print('new lr: %f' % opt_dict["lr"])
                enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                           lr=opt_dict["lr"])
                dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                           lr=opt_dict["lr"])
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl = test(vae, test_loader, "TEST", args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl = test(vae, test_loader, "TEST", args)
        au, au_var = calc_au(vae, test_loader)
        print("%d active units" % au)
        # print(au_var)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=50,
                                              shuffle=True)

    with torch.no_grad():
        calc_iwnll(vae, test_loader, args)
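calc_au is called throughout but not defined here. The standard active-units statistic (Burda et al.) counts latent dimensions whose posterior mean varies across the data by more than a small threshold. A minimal sketch, assuming the posterior means for the whole dataset have been stacked into one tensor:

import torch

def calc_au_sketch(mus, delta=0.01):
    # mus: (N, nz) tensor of posterior means, one row per example
    au_var = mus.var(dim=0)  # variance of mu along the data axis
    return (au_var >= delta).sum().item(), au_var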
Example #26
     type=str,
     default="mnist",
     metavar='S',
     help=
     'Training mode selection. Choices: mnist, synthetic_timeseries, cell_timeseries. (default: mnist)'
 )
 args = parser.parse_args()
 model_filepath = "model-{}.pth".format(args.train_mode)
 root_path = "results/{}".format(args.train_mode)
 try:
     os.makedirs(root_path)
 except OSError:
     pass  # results directory may already exist
 is_cuda = not args.no_cuda
 device = torch.device("cuda" if is_cuda else "cpu")
 model = VAE(dropout=args.dropout,
             input_dim=input_dims[args.train_mode]).to(device)
 try:
     model = load_model(model_filepath, model)
     logger.info("Loading model from {}".format(model_filepath))
 except (OSError, RuntimeError):  # no checkpoint yet, or incompatible weights
     logger.info("Creating VAE model from scratch")
     model = VAE(dropout=args.dropout,
                 input_dim=input_dims[args.train_mode]).to(device)
 if args.train_mode == 'mnist':
     train_mnist(model, device, args.epochs, root_path)
 elif args.train_mode == "synthetic_timeseries":
     model.decoder.sigmoid = False  # disable sigmoid from the final decoder layer
     train_synthetic_timeseries(model, device, args.epochs, root_path)
 elif args.train_mode == "cell_timeseries":
     model.decoder.sigmoid = False  # disable sigmoid from the final decoder layer
     train_cell_timeseries(model, device, args.epochs, root_path)
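load_model above is an external helper; a hypothetical version matching the call site would be no more than a state-dict restore:

import torch

def load_model_sketch(path, model, device="cpu"):
    # map_location keeps a GPU-trained checkpoint loadable on CPU
    model.load_state_dict(torch.load(path, map_location=device))
    return model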
Example #27
class VAESampler:
    def __init__(self, decode_from, params, cuda=False):
        self.decode_from = decode_from
        self.params = params
        params.enc_nh = params.dec_nh  # not sure why this is necessary...

        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        # do I need these?
        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)

        params.device = self.device = torch.device("cuda" if cuda else "cpu")

        self.encoder = LSTMEncoder(params, self.vocab_size, model_init,
                                   emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)

        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # assuming models were trained on a gpu...
        if cuda:
            self.vae.load_state_dict(torch.load(self.decode_from))
        else:
            self.vae.load_state_dict(
                torch.load(self.decode_from, map_location='cpu'))

    def to_s(self, decoded):
        return [' '.join(item) for item in decoded]

    def beam(self, z, K=5):
        decoded_batch = self.vae.decoder.beam_search_decode(z, K)
        return self.to_s(decoded_batch)

    def sample(self, z, temperature=1.0):
        decoded_batch = self.vae.decoder.sample_decode(z, temperature)
        return self.to_s(decoded_batch)

    def greedy(self, z):
        decoded_batch = self.vae.decoder.greedy_decode(z)
        return self.to_s(decoded_batch)

    def str2ids(self, s):
        "encode string s as list of word ids"
        raise NotImplementedError

    def encode(self, t):
        """
        Returns (z, mu, log_var) from encoder given list of strings.

        z is a sample from gaussian specified with (mu, log_var)
        """
        str_ids = []
        for s in t:
            ids = self.str2ids(s)
            str_ids.append(ids)
        tensor = self.train_data._to_tensor(str_ids, True, self.device)[0]
        z, (mu, log_var) = self.vae.encoder.sample(tensor, 1)
        return z, mu, log_var

    def z(self, t):
        "return sampled latent zs for list of strings t"
        z, mu, logvar = self.encode(t)
        return z.squeeze(1)

    def mu(self, t):
        "return mean of latent gaussian for list of strings t"
        z, mu, logvar = self.encode(t)
        return mu.squeeze(1)
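A hedged usage sketch for VAESampler: encode two sentences to their latent means and greedy-decode a linear interpolation between them. The checkpoint path is a placeholder, `params` is assumed to be the restored training config, and str2ids must first be implemented in a subclass (it raises NotImplementedError above):

import torch

sampler = VAESampler('vae_checkpoint.pt', params, cuda=False)  # placeholder path
mu = sampler.mu(['the cat sat on the mat', 'a storm is coming'])
for alpha in torch.linspace(0.0, 1.0, steps=5):
    z = ((1 - alpha) * mu[0] + alpha * mu[1]).unsqueeze(0)  # (1, nz)
    print(sampler.greedy(z))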
Example #28
def main(args):
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    log_niter = max(1, (len(train_data) //
                        (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.fb == 3:
        encoder = DeltaGaussianLSTMEncoder(args, vocab_size, model_init,
                                           emb_init)
        args.enc_nh = args.dec_nh
    elif args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # if args.eval:
    #     logging('begin evaluation')
    #     vae.load_state_dict(torch.load(args.load_path))
    #     vae.eval()
    #     with torch.no_grad():
    #         test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
    #                                                       device=device,
    #                                                       batch_first=True)

    #         test(vae, test_data_batch, test_labels_batch, "TEST", args)
    #         au, au_var = calc_au(vae, test_data_batch)
    #         logging("%d active units" % au)
    #         # print(au_var)

    #         test_data_batch = test_data.create_data_batch(batch_size=1,
    #                                                       device=device,
    #                                                       batch_first=True)
    #         calc_iwnll(vae, test_data_batch, args)

    #     return

    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # best_kl = best_nll = best_ppl = 0
    # pre_mi = 0
    discriminator.train()
    start = time.time()

    # kl_weight = args.kl_start
    # if args.warm_up > 0:
    #     anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))
    # else:
    #     anneal_rate = 0

    # dim_target_kl = args.target_kl / float(args.nz)

    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    acc_cnt = 1
    acc_loss = 0.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()
        for i in np.random.permutation(len(train_data_batch)):

            batch_data = train_data_batch[i]
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(
                batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                               clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                train_loss = report_loss / report_num_sents


                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f, ' \
                       'time %.2fs' %
                       (epoch, iter_, train_loss, report_correct / report_num_sents,
                        time.time() - start))

                #sys.stdout.flush()

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))
        print(report_num_sents)
        discriminator.eval()

        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch,
                             "VAL", args)
            # print(au_var)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict[
                    "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(),
                                           lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch,
                                 test_labels_batch, "TEST", args)

        discriminator.train()

    # compute importance weighted estimate of log p(x)
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch,
                         "TEST", args)
Example #29
def main(args):
    class uniform_initializer(object):
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')

    print(args)

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)

    vocab = train_data.vocab
    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    vae = VAE(encoder, decoder, args).to(device)

    if args.eval:
        print('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)

            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            print("%d active units" % au)
            # print(au_var)

            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            calc_iwnll(vae, test_data_batch, args)

        return

    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                              lr=1.0,
                              momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                              lr=1.0,
                              momentum=args.momentum)
    opt_dict['lr'] = 1.0

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    aggressive_flag = bool(args.aggressive)
    vae.train()
    start = time.time()

    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)

    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)
    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0
        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size

            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                enc_optimizer.step()

                id_ = np.random.randint(0, len(train_data_batch))  # random_integers is deprecated; randint's high is exclusive

                batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

                # if sub_iter >= 30:
                #     break

            # print(sub_iter)

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()
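
            # During the aggressive phase the encoder was already updated in the
            # inner loop above, so only the decoder steps here.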

            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, '
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f, '
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            iter_ += 1
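
            # Once per epoch, end aggressive training as soon as the validation
            # mutual information between x and z stops increasing.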

            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                print("pre mi:%.4f. cur mi:%.4f" % (pre_mi, cur_mi))
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, au_var = calc_au(vae, val_data_batch)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)
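
        # Plateau-based LR decay: after `decay_epoch` epochs without improvement
        # (and past epoch 15), reload the best checkpoint and shrink the learning
        # rate by `lr_decay`, up to `max_decay` times.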

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                             args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))

    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        print("%d active units" % au)
        # print(au_var)
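
    # Rebatch with batch_size=1 so the importance-weighted NLL estimate is
    # computed sentence by sentence.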

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
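
# A minimal, self-contained sketch of the linear KL-annealing schedule used in
# the training loop above; `kl_weight_at` and its argument names are
# illustrative stand-ins, not part of the original code.
def kl_weight_at(step, kl_start=0.1, warm_up_epochs=10, batches_per_epoch=1000):
    """KL weight after `step` minibatches: ramps linearly from kl_start to 1.0
    over warm_up_epochs * batches_per_epoch steps, then stays at 1.0."""
    anneal_rate = (1.0 - kl_start) / (warm_up_epochs * batches_per_epoch)
    return min(1.0, kl_start + step * anneal_rate)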
Example #30
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
                        help="The directory where checkpoints are saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
    parser.add_argument("--num_sents", default=10, type=int, help="Total sentences to generate.")


    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")


    parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--global_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")

    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")


    ## Variational auto-encoder
    parser.add_argument("--nz", default=32, type=int,
                        help="Latent space dimension.")

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")

    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)


    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()


    global_step = args.global_step_eval

    output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary that you have fine-tuned
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)

    model_encoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary that you have fine-tuned
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

    # pdb.set_trace()
    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Note: resize_token_embeddings expects the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    
    # Evaluation
    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)

    os.makedirs(args.output_dir, exist_ok=True)
    args.output_generation_file = os.path.join(args.output_dir, f"generation_from_vae_prior_t{args.temperature}_p{args.top_p}.txt")
    # args.output_generation_file = args.train_data_file
    result = evaluate_generation_fromp_prior(model_vae, tokenizer_decoder, args)
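
    # BLEU-5 against the reference corpus gauges generation quality, while
    # Self-BLEU among the generated sentences gauges diversity (higher
    # Self-BLEU means more repetitive samples).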

    
    bleu5 = Bleu(test_text=args.output_generation_file,
                 real_text=args.eval_data_file,
                 num_real_sentences=args.num_sents,
                 num_fake_sentences=args.num_sents,
                 gram=5).get_score()
    logger.info(f'The bleu score is {bleu5}')

    sbleu5 = SelfBleu(test_text=args.output_generation_file,
                      num_sentences=args.num_sents,
                      gram=5).get_score()
    logger.info(f'The self-bleu score is {sbleu5}')

    args.eval_results_file = os.path.join(args.output_dir, f"eval_results_t{args.temperature}_p{args.top_p}.txt")
    eval_results = {'bleu5':bleu5 , 'sbleu5':sbleu5}
    with open(args.eval_results_file, "w") as writer:
        logger.info("***** SHOW the quantitative evaluation results *****")
        for key in sorted(eval_results.keys()):
            writer.write("%s %s\n" % (key, str(eval_results[key])))
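
# A minimal sketch of the nucleus (top-p) filtering that the --top_p flag
# controls during sampling; `top_p_filter` is an illustrative standalone
# helper (assuming a 1-D logits tensor), not the repo's own implementation.
import torch

def top_p_filter(logits, top_p=0.9, filter_value=-float("inf")):
    """Mask all logits outside the smallest set of tokens whose cumulative
    probability exceeds top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    to_remove = cumulative_probs > top_p
    # shift right so the first token that crosses the threshold is kept
    to_remove[1:] = to_remove[:-1].clone()
    to_remove[0] = False
    logits[sorted_indices[to_remove]] = filter_value
    return logits

# Usage sketch: probs = torch.softmax(top_p_filter(logits / temperature), dim=-1)
# followed by torch.multinomial(probs, num_samples=1).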