def __init__(self, decode_from, params, cuda=False):
    """Build an LSTM encoder/decoder text VAE and restore weights from *decode_from*.

    Args:
        decode_from: path to a saved VAE state dict.
        params: configuration namespace; mutated in place (enc_nh, device).
        cuda: when True, load onto GPU; otherwise remap the checkpoint to CPU.
    """
    self.decode_from = decode_from
    self.params = params
    # Mirror the decoder hidden size onto the encoder (original author was
    # unsure why this is required — presumably the checkpoint expects it).
    params.enc_nh = params.dec_nh
    self.train_data = MonoTextData(params.train_data, label=False)
    self.vocab = self.train_data.vocab
    self.vocab_size = len(self.vocab)
    weight_init = uniform_initializer(0.01)
    embed_init = uniform_initializer(0.1)
    params.device = self.device = torch.device("cuda" if cuda else "cpu")
    self.encoder = LSTMEncoder(params, self.vocab_size, weight_init, embed_init)
    self.decoder = LSTMDecoder(params, self.vocab, weight_init, embed_init)
    self.vae = VAE(self.encoder, self.decoder, params).to(params.device)
    # Checkpoints were presumably saved on GPU; remap to CPU when CUDA is off.
    load_kwargs = {} if cuda else {"map_location": 'cpu'}
    self.vae.load_state_dict(torch.load(self.decode_from, **load_kwargs))
def main():
    """Load a pre-trained image VAE and display one decoded random sample."""
    # MNIST datasets; the train split feeds the loader, the test split is
    # fetched (download side effect) but unused below.
    mnist_train_data = datasets.MNIST(
        '/home/ajays/Downloads/', download=True, transform=transforms.ToTensor()
    )
    mnist_test_data = datasets.MNIST(
        '/home/ajays/Downloads/', train=False, download=True
    )
    train_loader = torch.utils.data.DataLoader(
        mnist_train_data, batch_size=batch_size, shuffle=True
    )

    # Instantiate the image VAE and restore weights instead of training.
    vae = VAE(n_inputs=32)
    vae.load_state_dict(torch.load(LOAD_PATH))

    # Decode a random latent vector and render it as a 28x28 image.
    o_after = vae.decode(torch.randn((128)))
    plt.imshow(o_after.detach().numpy().reshape((28, 28)))
    plt.show()
def val_test(args):
    """Validation driver: rebuild the VAE, optionally restore weights, then run
    evaluation passes on the validation split, logging reconstructions,
    interpolations, losses and latent metrics to TensorBoard.

    Args:
        args: parsed CLI namespace (output_folder, hidden_size, enc/dec_type,
              lr, recons_loss, weights, device, ... as read below).
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)
    # Fixed batch of test images, reused for every logged image grid.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    input_dim = 3  # channel count — presumably RGB input; confirm with dataset
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    # if torch.cuda.device_count() > 1 and args.device == "cuda":
    #     model = torch.nn.DataParallel(model)
    # Optimizer is built only so load_state/save_state can (de)serialize it;
    # no training step happens in this function.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    discriminators = {}
    # A GAN-style reconstruction loss brings its own discriminator + optimizer.
    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]
    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)
    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0
    stop_patience = args.stop_patience  # NOTE(review): assigned but never used below
    best_loss = torch.tensor(np.inf)
    # NOTE(review): epoch count is hard-coded to 4 — looks like a debugging
    # leftover rather than a config value; confirm intent.
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader, args, discriminators, True)
        # if args.weights == "init" and epoch==1:
        #     epoch+=1
        #     break
        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1, args.device, writer)
        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"], best_loss, args.recons_loss, epoch, save_filename)
        print(val_loss_dict)
def get_model(model, hidden_size, k, num_channels, resolution, num_classes):
    """Build an ImageNet classifier on top of a pre-trained autoencoder backbone.

    Args:
        model: backbone type — "vqvae", "vae" or "aae".
        hidden_size: backbone latent size; also selects the checkpoint path.
                     When <= 0 no checkpoint is loaded (random backbone).
        k: codebook size (used by the VQ-VAE backbone only).
        num_channels: input image channels.
        resolution: input image resolution, passed to the classifier head.
        num_classes: accepted for interface compatibility; not used here.

    Returns:
        ImgnetClassifier wrapping the (possibly checkpoint-initialized) backbone.

    Raises:
        ValueError: for an unknown backbone type (previously this fell through
        with model=None and an unbound checkpoint path, crashing later with a
        confusing NameError).
    """
    if model == "vqvae":
        # hidden_size == 256 checkpoints predate the hs_<size>/ directory
        # layout; every other size uses it (the original elif/else branches
        # were byte-identical and have been collapsed).
        if hidden_size == 256:
            ckpt_path = "models/imagenet/best.pt"
        else:
            ckpt_path = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
    elif model == "vae":
        ckpt_path = f"models/imagenet_hs_128_{hidden_size}_vae.pt"
        model = VAE(num_channels, hidden_size, hidden_size)
    elif model == "aae":
        ckpt_path = f"models/aae/imagenet_hs_32_{hidden_size}/best.pt"
        model = AAE(32, num_channels, hidden_size)
    else:
        raise ValueError(f"unknown model type: {model!r}")
    imgnetclassifier = ImgnetClassifier(model, hidden_size, resolution)
    if hidden_size > 0:
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt)
    return imgnetclassifier
def get_model(model, hidden_size, num_channels, resolution, enc_type, dec_type, num_classes, k=512):
    """Build an image classifier on a pre-trained backbone (vqvae/vae/acai) or
    a fully supervised classifier.

    Args:
        model: "vqvae", "vae", "acai" or "supervised".
        hidden_size: backbone latent size; <= 0 skips checkpoint loading.
        num_channels, resolution: input image geometry.
        enc_type, dec_type: encoder/decoder depth variants (also in ckpt path).
        num_classes: classifier output classes.
        k: VQ-VAE codebook size.

    Returns:
        ImgClassifier (backbone-based) or SupervisedImgClassifier.
    """
    # Keep the original type string: `model` is rebound to a module below.
    model_type = model
    # NOTE(review): reads module-level `args` (recons_loss, train_dataset,
    # img_res) — hidden dependency; confirm these are set before calling.
    CKPT_DIR = f"models/{model}_{args.recons_loss}/{args.train_dataset}/depth_{enc_type}_{dec_type}_hs_{args.img_res}_{hidden_size}/best.pt"
    if model == "vqvae":
        model = VectorQuantizedVAE(num_channels, hidden_size, k, enc_type, dec_type)
    elif model == "vae":
        model = VAE(num_channels, hidden_size, enc_type, dec_type)
    elif model == "acai":
        model = ACAI(resolution, num_channels, hidden_size, enc_type, dec_type)
    else:
        model = None
    # Bug fix: the original compared the *rebound* `model` (a module object or
    # None) against the string "supervised", which was always unequal, so the
    # supervised branch was unreachable. Compare the original string instead.
    if model_type != "supervised":
        imgclassifier = ImgClassifier(model, hidden_size, resolution, num_classes)
    else:
        imgclassifier = SupervisedImgClassifier(hidden_size, enc_type, resolution, num_classes)
    if hidden_size > 0 and model is not None:
        ckpt = torch.load(CKPT_DIR)
        model.load_state_dict(ckpt["model"])
    return imgclassifier
def get_vae_recons(loader, hidden_size=256, ckpt_path="./models/imagenet_hs_128_256_vae.pt"):
    """Return a grid of VAE reconstructions for one batch drawn from *loader*.

    Args:
        loader: iterable of (images, labels) batches.
        hidden_size: latent size of the VAE to instantiate.
        ckpt_path: checkpoint to restore (new parameter; default preserves the
                   previously hard-coded path, so callers are unaffected).

    Returns:
        A torchvision image grid (8 per row) of the reconstructions.
    """
    from types import SimpleNamespace  # stdlib; clearer than an ad-hoc empty class

    model = VAE(3, hidden_size, hidden_size).to(DEVICE)
    model.load_state_dict(torch.load(ckpt_path))
    # Minimal stand-in for the argparse namespace generate_samples expects.
    args = SimpleNamespace(device=DEVICE)
    gen_img, _ = next(iter(loader))
    # NOTE(review): `vae` here is presumably an imported module that provides
    # generate_samples (it is not defined locally) — confirm against imports.
    reconstruction = vae.generate_samples(gen_img, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8)
    return grid
def get_vae(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, beta=1):
    """Assemble an Optimus-style VAE and restore weights from the full checkpoint.

    Args:
        model_encoder, model_decoder: pre-built encoder/decoder networks.
        tokenizer_encoder, tokenizer_decoder: matching tokenizers.
        beta: KL weighting hyper-parameter forwarded to the VAE.

    Returns:
        The VAE with restored weights, moved to the detected device.
    """
    ArgsObj = namedtuple("Args", ["latent_size", "device", "fb_mode", "beta"])
    args = ArgsObj(latent_size=LATENT_SIZE_LARGE, device=get_device(),
                   fb_mode=0, beta=beta)
    ckpt_file = os.path.join(OUTPUT_DIR, "checkpoint-full-31250", "training.bin")
    # Remap GPU-saved tensors to CPU when CUDA is unavailable; map_location=None
    # is torch.load's default (load as saved).
    map_loc = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(ckpt_file, map_location=map_loc)
    model_vae = VAE(model_encoder, model_decoder,
                    tokenizer_encoder, tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint["model_state_dict"])
    model_vae.to(args.device)
    return model_vae
def create_model(args, vocab):
    """Construct an LSTM encoder/decoder VAE on args.device.

    Args:
        args: config namespace (enc_type, vocab_size, enc_nh/dec_nh, device).
        vocab: decoder vocabulary object.

    Returns:
        The assembled VAE, moved to args.device.

    Raises:
        ValueError: if args.enc_type is anything other than 'lstm'.
    """
    # Guard clause: only the LSTM encoder is implemented.
    if args.enc_type != 'lstm':
        raise ValueError("the specified encoder type is not supported")
    weight_init = uniform_initializer(0.01)
    embed_init = uniform_initializer(0.1)
    encoder = LSTMEncoder(args, args.vocab_size, weight_init, embed_init)
    # Keep encoder hidden size in sync with the decoder's.
    args.enc_nh = args.dec_nh
    decoder = LSTMDecoder(args, vocab, weight_init, embed_init)
    return VAE(encoder, decoder, args).to(args.device)
def get_model(ae_type, hidden_size, k, num_channels):
    """Build an ImageNet classifier over a checkpoint-initialized autoencoder.

    Args:
        ae_type: "vqvae" or "vae".
        hidden_size: backbone latent size (also selects the vqvae ckpt path).
        k: VQ-VAE codebook size.
        num_channels: input image channels.

    Returns:
        ImgnetClassifier wrapping the loaded backbone.

    Raises:
        ValueError: for an unknown ae_type (previously this fell through with
        every local unbound and crashed later with a NameError).
    """
    if ae_type == "vqvae":
        # hidden_size == 256 checkpoints predate the hs_<size>/ layout; the
        # original 128/else branches were identical and have been collapsed.
        if hidden_size == 256:
            ckpt_path = "models/imagenet/best.pt"
        else:
            ckpt_path = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
        imgnetclassifier = ImgnetClassifier(model, hidden_size)
    elif ae_type == "vae":
        ckpt_path = "models/imagenet_vae.pt"
        model = VAE(num_channels, hidden_size, 4096)
        imgnetclassifier = ImgnetClassifier(model, 4)
    else:
        raise ValueError(f"unknown ae_type: {ae_type!r}")
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt)
    return imgnetclassifier
def _save_with_retry(args, model_to_save, output_dir, args_filename):
    """Save a pretrained model plus the training args into *output_dir*.

    On Philly, storage failures are transient, so retry indefinitely until the
    save succeeds (preserves the original best-effort loop, but no longer
    swallows KeyboardInterrupt via a bare except).
    """
    if args.use_philly:
        save_solid = False
        while not save_solid:
            try:
                model_to_save.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, args_filename))
                save_solid = True
            except Exception:
                # Best-effort retry; a persistent failure will loop forever —
                # matches the original behavior.
                pass
    else:
        model_to_save.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, args_filename))


def main():
    """CLI entry point: train an Optimus-style text VAE (e.g. BERT encoder +
    GPT-2 decoder) with optional distributed training, then save encoder and
    decoder checkpoints.

    Bug fix vs. the original: the non-Philly decoder save wrote its args file
    as 'training_encoder_args.bin' inside the decoder directory; it now writes
    'training_decoder_args.bin' (handled uniformly by _save_with_retry).
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default=None, type=str, help="The dataset.")

    ## Other parameters
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--ExpName", default="", type=str,
                        help="The experiment name used in Azure Table.")

    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int,
                        help="Latent space dimension.")
    parser.add_argument("--use_deterministic_connect", action='store_true',
                        help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
    parser.add_argument("--use_beta_schedule", action='store_true',
                        help="Use cyclical beta schedule for auto-encoders.")

    ## Objective functions
    parser.add_argument("--mlm", action='store_true',
                        help="Train with masked-language modeling loss instead of language modeling.")
    parser.add_argument("--mlm_probability", type=float, default=0.15,
                        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--beta", type=float, default=1.0,
                        help="The weighting hyper-parameter of the KL term in VAE")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    # Training Schedules
    parser.add_argument("--ratio_increase", default=0.25, type=float,
                        help="Learning schedule, the percentage for the annealing stage.")
    parser.add_argument("--ratio_zero", default=0.25, type=float,
                        help="Learning schedule, the percentage for the pure auto-encoding stage.")
    parser.add_argument("--fb_mode", default=0, type=int,
                        help="free bit training mode.")
    parser.add_argument("--dim_target_kl", default=3.0, type=float,
                        help="dim_target_kl free bit training mode.")
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")

    ## IO: Logging and Saving
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    # NOTE(review): "gloabl" is a typo, but renaming the flag would break
    # existing launch scripts, so it is kept as-is.
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="For distant debugging.")
    parser.add_argument('--world-size', default=ompi_size(), type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='tcp://' + get_master_ip() + ':23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--port', type=str, default='51115', help="Port")

    args = parser.parse_args()
    # Rebuild the rendezvous URL from the (possibly overridden) port.
    args.dist_url = 'tcp://' + get_master_ip() + ':' + args.port

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
    logger = logging.getLogger(__name__)
    rank_node = ompi_rank()
    args.distributed = args.world_size > 1
    logger.info("Rank {} distributed: {}".format(rank_node, args.distributed))

    # Early sanity checks on argument combinations.
    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
                         "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
                         "or remove the --do_eval argument.")
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. "
                         "Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    if args.distributed:
        # Initializes the distributed backend which takes care of synchronizing nodes/GPUs.
        torch.distributed.init_process_group(
            backend=args.dist_backend, init_method=args.dist_url,
            world_size=args.world_size, rank=ompi_rank(), group_name='mtorch')
        logger.info("World Size is {}, Backend is {}, Init Method is {}, rank is {}".format(
            args.world_size, args.dist_backend, args.dist_url, ompi_rank()))

    gpus = list(gpu_indices())
    args.n_gpu = len(gpus)
    args.local_rank = ompi_rank()
    torch.cuda.set_device(gpus[0])
    device = torch.device("cuda", gpus[0])
    args.device = device
    logger.info('Rank {}, gpus: {}, get_rank: {}'.format(rank_node, gpus, torch.distributed.get_rank()))
    logger.info(f'Local rank is {args.local_rank}, {rank_node}')
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Experiment/table names encode the key hyper-parameters for Azure Table.
    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
    if ompi_rank() == 0:
        try:
            ts.create_table(table_name)
        except Exception:
            # Best-effort: the table may already exist.
            pass

    # Set seed
    set_seed(args)

    ## Encoder
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
    encoder_config = encoder_config_class.from_pretrained(
        args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        # Default to the max possible input block size for this model.
        args.block_size = tokenizer_encoder.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
    model_encoder = encoder_model_class.from_pretrained(
        args.encoder_model_name_or_path,
        from_tf=bool('.ckpt' in args.encoder_model_name_or_path),
        config=encoder_config, latent_size=args.latent_size)

    ## Decoder
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
    decoder_config = decoder_config_class.from_pretrained(
        args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)
    setattr(decoder_config, "latent_size", args.latent_size)
    model_decoder = decoder_model_class.from_pretrained(
        args.decoder_model_name_or_path,
        from_tf=bool('.ckpt' in args.decoder_model_name_or_path),
        config=decoder_config, latent_size=args.latent_size)

    # Add a padding token to GPT2 and grow the embedding matrix to match the
    # new tokenizer size (resize expects the FULL new vocabulary size).
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))
    assert tokenizer_decoder.pad_token == '<PAD>'

    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)

    # Distributed training (should be after apex fp16 initialization).
    if args.distributed:
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=gpus)
    elif args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae)

    logger.info("Training/evaluation parameters %s", args)

    global_step = 0
    if args.do_train:
        train_dataloader = build_dataload_and_cache_examples(
            args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)
        global_step, tr_loss = train(args, train_dataloader, model_vae,
                                     tokenizer_encoder, tokenizer_decoder, table_name)
        logger.info("Rank %d, global_step = %s, average loss = %s", ompi_rank(), global_step, tr_loss)

    # Save checkpoints (first process only); they can be reloaded later with
    # from_pretrained().
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
        output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
        if not os.path.exists(output_encoder_dir) and args.local_rank in [-1, 0]:
            os.makedirs(output_encoder_dir)
        if not os.path.exists(output_decoder_dir) and args.local_rank in [-1, 0]:
            os.makedirs(output_decoder_dir)
        logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
        logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)

        # Unwrap DistributedDataParallel/DataParallel before saving.
        model_encoder_to_save = model_vae.module.encoder if hasattr(model_vae, 'module') else model_vae.encoder
        model_decoder_to_save = model_vae.module.decoder if hasattr(model_vae, 'module') else model_vae.decoder

        # Good practice: save the training arguments together with each model.
        _save_with_retry(args, model_encoder_to_save, output_encoder_dir, 'training_encoder_args.bin')
        _save_with_retry(args, model_decoder_to_save, output_decoder_dir, 'training_decoder_args.bin')
def __init__(self, config, vocab, rev_vocab, PAD_token=0):
    """Multi-sentiment VAE: three parallel VAEs (pos/neu/neg), each with its
    own decoder and bag-of-words head, sharing one embedder and one sequence
    encoder; builds per-sentiment AdamW optimizers.

    Args:
        config: hyper-parameter namespace (emb_size, n_hidden, bow_size, ...).
        vocab: vocabulary (id -> token).
        rev_vocab: reverse vocabulary (token -> id); must map '<pad>' to PAD_token.
        PAD_token: padding token id used by the embedding layer.
    """
    super(multiVAE, self).__init__()
    assert rev_vocab['<pad>'] == PAD_token
    self.vocab = vocab
    self.vocab_size = len(self.vocab)
    self.embed_size = config.emb_size
    self.hidden_size = config.n_hidden
    self.bow_size = config.bow_size
    self.rev_vocab = rev_vocab
    self.dropout = config.dropout
    self.go_id = self.rev_vocab["<s>"]    # decoder start-of-sequence id
    self.eos_id = self.rev_vocab["</s>"]  # decoder end-of-sequence id
    self.maxlen = config.maxlen
    self.clip = config.clip
    self.temp = config.temp
    self.full_kl_step = config.full_kl_step
    self.z_size = config.z_size
    self.init_w = config.init_weight
    self.softmax = nn.Softmax(dim=1)
    self.bidirectional = config.bidirectional
    self.lr_ae = config.lr_ae
    self.lr_vae = config.lr_vae
    # If the LSTM is bidirectional, the two directions are concatenated.
    self.encoder_output_size = self.hidden_size * (1 + int(self.bidirectional))
    # Title and first line are concatenated together as the context.
    self.context_dim = self.encoder_output_size * 2
    self.decoder_input_size = self.z_size
    # build components
    self.layers = nn.ModuleDict()
    self.layers["embedder"] = nn.Embedding(self.vocab_size, self.embed_size, padding_idx=PAD_token)
    # Encodes the title and each poem line; bidirectional LSTM by default,
    # with the final states of both directions concatenated.
    self.layers["seq_encoder"] = Encoder(embedder=self.layers["embedder"], input_size=config.emb_size, hidden_size=config.n_hidden, bidirectional=self.bidirectional, n_layers=config.n_layers, noise_radius=config.noise_radius)
    # Prior networks: one VAE per sentiment polarity.
    self.layers["neg_vae"] = VAE(target_size=self.encoder_output_size, z_size=self.z_size, dropout=self.dropout, init_weight=self.init_w)
    self.layers["neu_vae"] = VAE(target_size=self.encoder_output_size, z_size=self.z_size, dropout=self.dropout, init_weight=self.init_w)
    self.layers["pos_vae"] = VAE(target_size=self.encoder_output_size, z_size=self.z_size, dropout=self.dropout, init_weight=self.init_w)
    # Word-level bag-of-words loss heads, one per sentiment.
    self.layers["bow_project_pos"] = nn.Sequential(
        nn.Linear(self.decoder_input_size, self.bow_size),
        nn.LeakyReLU(),
        nn.Dropout(self.dropout),
        nn.Linear(self.bow_size, self.vocab_size)
    )
    self.layers["bow_project_neu"] = nn.Sequential(
        nn.Linear(self.decoder_input_size, self.bow_size),
        nn.LeakyReLU(),
        nn.Dropout(self.dropout),
        nn.Linear(self.bow_size, self.vocab_size)
    )
    self.layers["bow_project_neg"] = nn.Sequential(
        nn.Linear(self.decoder_input_size, self.bow_size),
        nn.LeakyReLU(),
        nn.Dropout(self.dropout),
        nn.Linear(self.bow_size, self.vocab_size)
    )
    # Per-sentiment sequence decoders sharing the embedder.
    self.layers["vae_decoder_pos"] = Decoder(embedder=self.layers["embedder"], input_size=self.embed_size, hidden_size=self.hidden_size, vocab_size=self.vocab_size, n_layers=1)
    self.layers["vae_decoder_neu"] = Decoder(embedder=self.layers["embedder"], input_size=self.embed_size, hidden_size=self.hidden_size, vocab_size=self.vocab_size, n_layers=1)
    self.layers["vae_decoder_neg"] = Decoder(embedder=self.layers["embedder"], input_size=self.embed_size, hidden_size=self.hidden_size, vocab_size=self.vocab_size, n_layers=1)
    # Projects [z; context] to the decoder's initial state.
    self.layers["init_decoder"] = nn.Sequential(
        nn.Linear(self.decoder_input_size + self.context_dim, self.hidden_size),
        nn.BatchNorm1d(self.hidden_size, eps=1e-05, momentum=0.1),
        nn.LeakyReLU()
    )
    # Projects z alone to a decoder hidden state.
    self.layers["init_decoder_hidden"] = nn.Sequential(
        nn.Linear(self.decoder_input_size, self.hidden_size),
        nn.BatchNorm1d(self.hidden_size, eps=1e-05, momentum=0.1),
        nn.LeakyReLU()
    )
    # NOTE(review): init_weights is applied to these four heads only; the
    # "init_decoder" block is left at its default initialization — confirm.
    self.layers["init_decoder_hidden"].apply(self.init_weights)
    self.layers["bow_project_neg"].apply(self.init_weights)
    self.layers["bow_project_neu"].apply(self.init_weights)
    self.layers["bow_project_pos"].apply(self.init_weights)
    # Per-sentiment end-to-end optimizers (embedder + encoder + decoder + VAE
    # + bow head). NOTE(review): these use lr_vae, not lr_ae, despite the
    # name "optimizer_AE" — confirm whether that is intentional.
    self.optimizer_AE = {
        'pos': optim.AdamW(list(self.layers["embedder"].parameters())
                           + list(self.layers["init_decoder"].parameters())
                           + list(self.layers["seq_encoder"].parameters())
                           + list(self.layers["vae_decoder_pos"].parameters())
                           + list(self.layers["pos_vae"].parameters())
                           + list(self.layers["bow_project_pos"].parameters()), lr=self.lr_vae),
        'neu': optim.AdamW(list(self.layers["embedder"].parameters())
                           + list(self.layers["init_decoder"].parameters())
                           + list(self.layers["seq_encoder"].parameters())
                           + list(self.layers["vae_decoder_neu"].parameters())
                           + list(self.layers["neu_vae"].parameters())
                           + list(self.layers["bow_project_neu"].parameters()), lr=self.lr_vae),
        'neg': optim.AdamW(list(self.layers["embedder"].parameters())
                           + list(self.layers["init_decoder"].parameters())
                           + list(self.layers["seq_encoder"].parameters())
                           + list(self.layers["vae_decoder_neg"].parameters())
                           + list(self.layers["neg_vae"].parameters())
                           + list(self.layers["bow_project_neg"].parameters()), lr=self.lr_vae)
    }
    # Per-sentiment optimizers over the VAE-specific parts only.
    self.optimizer_VAE = {
        'pos': optim.AdamW(list(self.layers["vae_decoder_pos"].parameters())
                           + list(self.layers["pos_vae"].parameters())
                           + list(self.layers["bow_project_pos"].parameters()), lr=self.lr_vae),
        'neu': optim.AdamW(list(self.layers["vae_decoder_neu"].parameters())
                           + list(self.layers["neu_vae"].parameters())
                           + list(self.layers["bow_project_neu"].parameters()), lr=self.lr_vae),
        'neg': optim.AdamW(list(self.layers["vae_decoder_neg"].parameters())
                           + list(self.layers["neg_vae"].parameters())
                           + list(self.layers["bow_project_neg"].parameters()), lr=self.lr_vae)
    }
    self.criterion_ce = nn.CrossEntropyLoss()
    self.softmax = nn.Softmax(dim=1)
    self.reconstruct_loss = dict()
def main():
    """Load a fine-tuned Optimus checkpoint (BERT encoder + GPT-2 decoder VAE)
    and play with its latent space.

    Depending on the flags, either interacts with user-supplied sentences
    (interpolation / analogy) or evaluates the latent space over a dataset.
    Reads all configuration from the command line; returns nothing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
                        help="The directory where checkpoints are saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default='Snli', type=str, help="The dataset.")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--total_sents", default=10, type=int, help="Total sentences to test recontruction.")
    parser.add_argument("--num_interpolation_steps", default=10, type=int,
                        help="Total sentences to test recontruction.")
    parser.add_argument("--play_mode", default="interpolation", type=str,
                        help="interpolation or reconstruction.")

    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    # NOTE: flag name is misspelled ("gloabl") but kept for CLI backward compatibility.
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")

    # Interact with users
    parser.add_argument("--interact_with_user_input", action='store_true',
                        help="Use user input to interact_with.")
    parser.add_argument("--sent_source", type=str, default="")
    parser.add_argument("--sent_target", type=str, default="")
    parser.add_argument("--sent_input", type=str, default="")
    # Fixed: default was the string "1.0"; use a real float to match type=float.
    parser.add_argument("--degree_to_target", type=float, default=1.0)

    ## Variational auto-encoder
    parser.add_argument("--nz", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                        "The training dataset will be truncated in block of this size for training."
                        "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")
    args = parser.parse_args()

    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    set_seed(args)
    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()

    global_step = args.gloabl_step_eval
    output_encoder_dir = os.path.join(
        args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(
        args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary that you have fine-tuned
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[
        args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(
        output_encoder_dir, latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name
        if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_encoder.to(args.device)
    if args.block_size <= 0:
        # Our input block size will be the max possible for the model
        args.block_size = tokenizer_encoder.max_len_single_sentence
    args.block_size = min(args.block_size,
                          tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary that you have fine-tuned
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
        args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(
        output_decoder_dir, latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name
        if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence
    args.block_size = min(args.block_size,
                          tokenizer_decoder.max_len_single_sentence)

    # Load full model (checkpoint saved on GPU; map to CPU first, moved later)
    output_full_dir = os.path.join(args.checkpoint_dir,
                                   'checkpoint-full-{}'.format(global_step))
    checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'),
                            map_location=torch.device('cpu'))

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    # Notice: resize_token_embeddings expect to receive the full size of the
    # new vocabulary, i.e. the length of the tokenizer.
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))
    assert tokenizer_decoder.pad_token == '<PAD>'

    # Evaluation
    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args)
    model_vae.load_state_dict(checkpoint['model_state_dict'])
    logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)

    if args.interact_with_user_input:
        if args.play_mode == 'interpolation':
            # Fixed: second condition previously re-checked sent_source;
            # interpolation needs BOTH a source and a target sentence.
            if len(args.sent_source) > 0 and len(args.sent_target) > 0:
                result = interpolate(model_vae, tokenizer_encoder,
                                     tokenizer_decoder, args)
            else:
                print('Please check: specify the source and target sentences!')
        if args.play_mode == 'analogy':
            # Fixed: second condition previously re-checked sent_source;
            # analogy needs source, target AND input sentences.
            if len(args.sent_source) > 0 and len(args.sent_target) > 0 and len(
                    args.sent_input) > 0:
                result = analogy(model_vae, tokenizer_encoder,
                                 tokenizer_decoder, args)
            else:
                print(
                    'Please check: specify the source, target and input analogy sentences!'
                )
    else:
        result = evaluate_latent_space(args,
                                       model_vae,
                                       tokenizer_encoder,
                                       tokenizer_decoder,
                                       prefix=global_step)
def main(args):
    """Train a text VAE with optional aggressive (lagging-encoder) updates.

    In aggressive mode the encoder is burned in with extra inner-loop steps
    until its loss stops improving, while the decoder takes one step per
    batch; aggressive mode is switched off once validation mutual information
    stops increasing. Saves the best model to ``args.save_path`` and finishes
    with an importance-weighted NLL estimate on the test set.

    NOTE(review): ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay`` are read from module-level globals — confirm they are
    defined at import time.
    """

    class uniform_initializer(object):
        # Initializes tensors uniformly in [-stdv, stdv].
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')
    print(args)

    # Tracks plateau-based learning-rate decay state across epochs.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, vocab=vocab)
    test_data = MonoTextData(args.test_data, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # Log roughly ten times per epoch.
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
    # Decoder expects the encoder hidden size to match its own.
    args.enc_nh = args.dec_nh
    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.optim == 'sgd':
        enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0)
        dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0)
        opt_dict['lr'] = 1.0
    else:
        enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                   lr=0.001,
                                   betas=(0.9, 0.999))
        opt_dict['lr'] = 0.001

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = -1
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    # KL annealing: weight ramps linearly from kl_start to 1 over warm_up epochs.
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    plot_data = train_data.data_sample(nsample=args.num_plot,
                                       device=device,
                                       batch_first=True)

    if args.plot_mode == 'multiple':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_multiple
    elif args.plot_mode == 'single':
        grid_z = generate_grid(args.zmin, args.zmax, args.dz, device, ndim=1)
        plot_fn = plot_single
        posterior_mean = []
        infer_mean = []
        posterior_mean.append(
            vae.calc_model_posterior_mean(plot_data[0], grid_z))
        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
    # NOTE(review): any other plot_mode leaves plot_fn/grid_z unbound but they
    # are used unconditionally at epoch end — confirm plot_mode is validated upstream.

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0

        for i in np.random.permutation(len(train_data_batch)):
            if args.plot_mode == "single":
                batch_data, _ = plot_data
            else:
                batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            # Aggressive encoder burn-in: keep updating only the encoder until
            # its running loss stops improving (checked every 15 sub-iters).
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)
                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)
                enc_optimizer.step()

                if args.plot_mode == "single":
                    batch_data_enc, _ = plot_data
                else:
                    # Fixed: np.random.random_integers is deprecated/removed;
                    # randint with half-open [0, n) is the exact equivalent of
                    # random_integers(0, n - 1).
                    id_ = np.random.randint(0, len(train_data_batch))
                    batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            if args.plot_mode == 'single' and epoch == 0 and aggressive_flag:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(posterior_mean[-1])
                    infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)
            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            # In aggressive mode the encoder was already stepped in the
            # burn-in loop; only the decoder steps here.
            if not aggressive_flag:
                enc_optimizer.step()
            dec_optimizer.step()

            if args.plot_mode == 'single' and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    posterior_mean.append(
                        vae.calc_model_posterior_mean(plot_data[0], grid_z))
                    if aggressive_flag:
                        infer_mean.append(infer_mean[-1])
                    else:
                        infer_mean.append(vae.calc_infer_mean(plot_data[0]))
                vae.train()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    mi = calc_mi(vae, val_data_batch)
                    vae.train()
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents, time.time() - start))
                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            if iter_ % args.plot_niter == 0 and epoch == 0:
                vae.eval()
                with torch.no_grad():
                    if args.plot_mode == 'single' and iter_ != 0:
                        plot_fn(infer_mean, posterior_mean, args)
                        return
                    elif args.plot_mode == "multiple":
                        plot_fn(vae, plot_data, grid_z, iter_, args)
                vae.train()

            iter_ += 1

            # Re-check MI once per "epoch-worth" of iterations; stop aggressive
            # training when MI on validation stops improving.
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")
                pre_mi = cur_mi
                # return

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        with torch.no_grad():
            plot_fn(vae, plot_data, grid_z, iter_, args)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl = test(vae, val_data_batch, "VAL", args)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                # Plateau: reload best weights, decay LR, rebuild optimizers.
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.optim == 'sgd':
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"])
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"])
                else:
                    enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
                    dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                               lr=opt_dict["lr"],
                                               betas=(0.5, 0.999))
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl = test(vae, test_data_batch, "TEST", args)

        vae.train()

    print('best_loss: %.4f, kl: %.4f, nll: %.4f, ppl: %.4f' \
          % (best_loss, best_kl, best_nll, best_ppl))
    sys.stdout.flush()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
def model_summary(model_type,
                  img_res,
                  hidden_size,
                  enc_type,
                  dec_type,
                  loss,
                  batch_size,
                  device=torch.device("cuda:1"),
                  verbose=True):
    """Summarize parameter and forward-pass memory sizes (MB) for a model.

    Builds the requested autoencoder ("acai", "vqvae" or "vae"), runs
    torchsummary over its encoder/decoder (and any discriminators implied by
    ``loss``), and scrapes the "Params size" / "Forward/backward pass size"
    figures out of the summary text.

    Returns:
        (encoder, decoder, discriminators) where encoder/decoder are dicts
        with "params" and "forward" sizes in MB, and discriminators maps a
        discriminator name to a (params_MB, forward_MB) tuple.

    NOTE(review): an unrecognized model_type leaves ``model`` unbound and
    raises NameError at the first summary call — confirm callers validate it.
    """
    # Regexes that pull the MB figures out of torchsummary's text output.
    pattern = re.compile(r"Params size \(MB\):(.*)\n")
    pattern2 = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")

    input_dim = 3
    enc_input_size = (input_dim, img_res, img_res)
    # Decoder consumes the encoder's 4x-downsampled feature map.
    dec_input_size = (hidden_size, img_res // 4, img_res // 4)

    # Fixed: removed leftover pdb.set_trace() debugger breakpoint.

    if verbose:
        print(f"model:{model_type}")
        print(f"depth:{enc_type}_{dec_type}")

    if model_type == "acai":
        model = ACAI(img_res, input_dim, hidden_size, enc_type,
                     dec_type).to(device)
    elif model_type == "vqvae":
        model = VectorQuantizedVAE(input_dim,
                                   hidden_size,
                                   enc_type=enc_type,
                                   dec_type=dec_type).to(device)
    elif model_type == "vae":
        model = VAE(input_dim,
                    hidden_size,
                    enc_type=enc_type,
                    dec_type=dec_type).to(device)

    encoder_summary, _ = torchsummary.summary_string(model.encoder,
                                                     enc_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    decoder_summary, _ = torchsummary.summary_string(model.decoder,
                                                     dec_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    if verbose:
        print(encoder_summary)
        print(decoder_summary)

    discriminators = {}
    if model_type == "acai":
        # ACAI always carries an interpolation discriminator.
        disc = Discriminator(input_dim, img_res, "image").to(device)
        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["interp_disc"] = (disc_param_size, disc_forward_size)

    # Reconstruction-loss discriminators: forward size is doubled because the
    # discriminator sees both real and reconstructed images per step.
    if loss == "gan":
        disc = Discriminator(input_dim, img_res, "image").to(device)
        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif loss == "comp":
        disc = AnchorComparator(input_dim * 2, img_res, "image").to(device)
        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_2" in loss:
        disc = ClubbedPermutationComparator(input_dim * 2, img_res,
                                            "image").to(device)
        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_6" in loss:
        disc = FullPermutationComparator(input_dim * 2, img_res,
                                         "image").to(device)
        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)

    encoder_param_size = float(re.search(pattern, encoder_summary).group(1))
    encoder_forward_size = float(
        re.search(pattern2, encoder_summary).group(1))
    decoder_param_size = float(re.search(pattern, decoder_summary).group(1))
    decoder_forward_size = float(
        re.search(pattern2, decoder_summary).group(1))

    if verbose:
        if "ACAI" in str(type(model)):
            print(
                f"discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )
        if loss == "gan":
            print(
                f"reconstruction discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )
        print(
            f"encoder:\n\tparams:{encoder_param_size}\n\tforward:{encoder_forward_size}"
        )
        print(
            f"decoder:\n\tparams:{decoder_param_size}\n\tforward:{decoder_forward_size}"
        )

    encoder = {"params": encoder_param_size, "forward": encoder_forward_size}
    decoder = {"params": decoder_param_size, "forward": decoder_forward_size}

    return encoder, decoder, discriminators
def main(args):
    """Train, evaluate, or decode from a Gaussian-LSTM text VAE.

    Three modes, selected by args:
      * ``args.eval``           — load a checkpoint and evaluate (NLL/PPL/AU/IW-NLL).
      * ``args.reconstruct_from`` — load a checkpoint and decode reconstructions.
      * otherwise               — full beta-VAE training with plateau LR decay,
        saving the best model to ``args.save_path``.

    NOTE(review): ``clip_grad``, ``decay_epoch``, ``lr_decay``, ``max_decay``
    and ``ns`` are read from module-level globals — confirm they are defined
    at import time.
    """
    global logging
    # Reconstruction/eval runs must not create a fresh experiment directory.
    debug = (args.reconstruct_from != "" or args.eval == True)
    logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug)

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # Tracks plateau-based learning-rate decay state across epochs.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    # Log roughly ten times per epoch.
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        # Decoder expects the encoder hidden size to match its own.
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)
        if args.reset_dec:
            vae.decoder.reset_parameters(model_init, emb_init)

    if args.eval:
        logging('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            logging("%d active units" % au)
            # print(au_var)
            # IW-NLL is computed one sentence at a time.
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            nll, ppl = calc_iwnll(vae, test_data_batch, args)
            logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
        return

    if args.reconstruct_from != "":
        print("begin decoding")
        sys.stdout.flush()
        vae.load_state_dict(torch.load(args.reconstruct_from))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            # test(vae, test_data_batch, "TEST", args)
            reconstruct(vae, test_data_batch, vocab, args.decoding_strategy,
                        args.reconstruct_to)
        return

    if args.opt == "sgd":
        enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
        dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    vae.train()
    start = time.time()

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(args.epochs):
            report_kl_loss = report_rec_loss = report_loss = 0
            report_num_words = report_num_sents = 0

            for i in np.random.permutation(len(train_data_batch)):
                batch_data = train_data_batch[i]
                batch_size, sent_len = batch_data.size()

                # not predict start symbol
                report_num_words += (sent_len - 1) * batch_size
                report_num_sents += batch_size

                # Fixed beta weighting (no annealing in this script).
                kl_weight = args.beta

                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                if args.iw_train_nsamples < 0:
                    loss, loss_rc, loss_kl = vae.loss(batch_data,
                                                      kl_weight,
                                                      nsamples=args.nsamples)
                else:
                    loss, loss_rc, loss_kl = vae.loss_iw(
                        batch_data,
                        kl_weight,
                        nsamples=args.iw_train_nsamples,
                        ns=ns)
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                loss_rc = loss_rc.sum()
                loss_kl = loss_kl.sum()

                enc_optimizer.step()
                dec_optimizer.step()

                report_rec_loss += loss_rc.item()
                report_kl_loss += loss_kl.item()
                report_loss += loss.item() * batch_size

                if iter_ % log_niter == 0:
                    #train_loss = (report_rec_loss + report_kl_loss) / report_num_sents
                    train_loss = report_loss / report_num_sents
                    logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                            'time elapsed %.2fs, kl_weight %.4f' %
                            (epoch, iter_, train_loss, report_kl_loss / report_num_sents,
                             report_rec_loss / report_num_sents, time.time() - start,
                             kl_weight))
                    #sys.stdout.flush()
                    report_rec_loss = report_kl_loss = report_loss = 0
                    report_num_words = report_num_sents = 0

                iter_ += 1

            logging('kl weight %.4f' % kl_weight)

            vae.eval()
            with torch.no_grad():
                loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL",
                                              args)
                au, au_var = calc_au(vae, val_data_batch)
                logging("%d active units" % au)
                # print(au_var)

            if args.save_ckpt > 0 and epoch <= args.save_ckpt:
                logging('save checkpoint')
                torch.save(
                    vae.state_dict(),
                    os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt'))

            if loss < best_loss:
                logging('update best loss')
                best_loss = loss
                best_nll = nll
                best_kl = kl
                best_ppl = ppl
                torch.save(vae.state_dict(), args.save_path)

            if loss > opt_dict["best_loss"]:
                opt_dict["not_improved"] += 1
                if opt_dict[
                        "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                    # Plateau: reload best weights, decay LR, rebuild optimizers.
                    opt_dict["best_loss"] = loss
                    opt_dict["not_improved"] = 0
                    opt_dict["lr"] = opt_dict["lr"] * lr_decay
                    vae.load_state_dict(torch.load(args.save_path))
                    logging('new lr: %f' % opt_dict["lr"])
                    decay_cnt += 1
                    enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
                    dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                              lr=opt_dict["lr"],
                                              momentum=args.momentum)
            else:
                opt_dict["not_improved"] = 0
                opt_dict["best_loss"] = loss

            if decay_cnt == max_decay:
                break

            if epoch % args.test_nepoch == 0:
                with torch.no_grad():
                    loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                                 args)

            if args.save_latent > 0 and epoch <= args.save_latent:
                # Fixed: device was hardcoded to "cuda", which breaks CPU-only
                # runs (args.cuda False => device == "cpu").
                visualize_latent(args, epoch, vae, device, test_data)

            vae.train()

    except KeyboardInterrupt:
        logging('-' * 100)
        logging('Exiting from training early')

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        logging("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        nll, ppl = calc_iwnll(vae, test_data_batch, args)
        logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl))
def main(args):
    """Train an image VAE (optionally with a GAN/comparator reconstruction
    discriminator), logging to TensorBoard and saving reconstruction /
    interpolation grids each epoch.

    Supports resuming: when ``args.weights == "load"`` the model, optimizer
    and discriminator states are restored and training continues from the
    saved epoch.
    """
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)

    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)

    # Optional discriminator for non-MSE reconstruction losses.
    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}")

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        os.makedirs(log_save_path)
        print(f"log:{log_save_path}", file=sys.stderr)
        sys.stderr.flush()
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    writer = SummaryWriter(log_save_path)
    print(f"train loader length:{len(train_loader)}", file=sys.stderr)

    best_loss = torch.tensor(np.inf)
    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.save_recons_img_grid("val", recons_input_img, model, 0, args)

    # Fixed: the loop previously ran range(1, args.num_epochs), discarding the
    # resume epoch returned by load_state and mis-aligning the `epoch+1`
    # logging below. Start from start_epoch so resumed runs continue where
    # they left off.
    for epoch in range(start_epoch, args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader)
        curr_loss = val(model, val_loader)
        # val_loss_dict, z = train_util.test(get_losses, model, val_loader, args, discriminators)
        print(f"epoch val loss:{curr_loss}", file=sys.stderr)
        sys.stderr.flush()

        train_util.save_recons_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)
train=True, download=True, transform=preproc_transform, ), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) test_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)( '../data/{}/'.format(DATASET), train=False, transform=preproc_transform), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) model = VAE(INPUT_DIM, DIM, Z_DIM).cuda() print(model) opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True) def train(): train_loss = [] model.train() for batch_idx, (x, _) in enumerate(train_loader): start_time = time.time() x = x.cuda() x_tilde, kl_d = model(x) loss_recons = F.mse_loss(x_tilde, x, size_average=False) / x.size(0) loss = loss_recons + kl_d
def main():
    """Entry point for training/evaluating an Optimus-style GPT-2 decoder VAE.

    Parses the CLI, sets up CUDA/distributed state and logging, builds (or
    loads) the GPT-2 decoder plus the AE wrapper, then optionally runs
    training (``--do_train``), metric evaluation (``--do_eval``) and
    per-checkpoint reconstruction evaluation (``--do_eval_rec``).

    Returns the ``results`` dict of the last evaluation phase (empty when no
    evaluation flag was given).

    Fixes vs. original:
      * ``do_eval_rec`` indexed ``checkpoint[1]`` although each entry of
        ``checkpoints`` is the single-element list ``[output_decoder_dir]``,
        so every run raised ``IndexError``; now uses ``checkpoint[0]``.
      * bare ``except:`` around the Azure-table creation narrowed to
        ``except Exception:``.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--checkpoint_dir", default=None, type=str,
                        help="The directory where checkpoints are saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default=None, type=str,
                        help="The dataset.")

    ## Other parameters
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--ExpName", default="", type=str,
                        help="The experiment name used in Azure Table.")
    parser.add_argument("--save_bert_gpt_init", action='store_true',
                        help="Use Philly for computing.")
    parser.add_argument("--length_weighted_loss", action='store_true',
                        help="Use sentence length re-weight the reconstruction loss.")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int,
                        help="Latent space dimension.")
    parser.add_argument("--use_deterministic_connect", action='store_true',
                        help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
    parser.add_argument("--use_pretrained_model", action='store_true',
                        help="Use pre-trained auto-encoder models as the initialization")
    parser.add_argument("--latent_as_gpt_memory", default=1, type=int,
                        help="Latent vector as memery for GPT2 to attend.")
    parser.add_argument("--latent_as_gpt_emb", default=1, type=int,
                        help="Latent vector as embeddings for GPT2.")

    ## Objective functions
    parser.add_argument("--mlm", action='store_true',
                        help="Train with masked-language modeling loss instead of language modeling.")
    parser.add_argument("--mlm_probability", type=float, default=0.15,
                        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--beta", type=float, default=1.0,
                        help="The weighting hyper-parameter of the KL term in VAE")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_eval_rec", action='store_true',
                        help="Whether to run eval reconstruction on a set of models.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    # Training Schedules
    parser.add_argument("--ratio_increase", default=0.25, type=float,
                        help="Learning schedule, the percentage for the annealing stage.")
    parser.add_argument("--ratio_zero", default=0.25, type=float,
                        help="Learning schedule, the percentage for the pure auto-encoding stage.")
    parser.add_argument("--fb_mode", default=0, type=int,
                        help="free bit training mode.")
    parser.add_argument("--dim_target_kl", default=3.0, type=float,
                        help="dim_target_kl free bit training mode.")
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")
    parser.add_argument("--use_pretrained_vae", action='store_true',
                        help="Use use_pretrained_vae as initialization, where beta value is specified in the folder")
    parser.add_argument("--use_random_weight", action='store_true',
                        help="Use random weights as initialization")

    ## IO: Logging and Saving
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    # NOTE: flag name keeps the historical typo "gloabl" so existing launch
    # scripts stay compatible.
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    # --- Sanity checks on flag combinations -------------------------------
    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError(
            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument.")
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # # Setup distant debugging if needed
    # if args.server_ip and args.server_port:
    #     # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
    #     import ptvsd
    #     print("Waiting for debugger attach")
    #     ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
    #     ptvsd.wait_for_attach()

    # --- Setup CUDA, GPU & distributed training ---------------------------
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging (only rank 0 / single-process logs at INFO)
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Experiment / Azure-table naming derived from the key hyper-parameters.
    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(
        args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(
            args.dim_target_kl) + '_Ra_' + str(
                args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
    try:
        ts.create_table(table_name)
    except Exception:  # table may already exist; best-effort creation
        print("pass")

    # Set seed
    set_seed(args)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    # --- Build or load the decoder ---------------------------------------
    if args.use_pretrained_model:
        # Load Optimus pre-trained model and tokenizer
        args.decoder_model_type = args.decoder_model_type.lower()
        global_step = args.gloabl_step_eval

        output_decoder_dir = os.path.join(
            args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir = os.path.join(
            args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
        checkpoints = [[output_decoder_dir]]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        # Load a trained Decoder model and vocabulary
        decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
            args.decoder_model_type]
        model_decoder = decoder_model_class.from_pretrained(
            output_decoder_dir, latent_size=args.latent_size)
        tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
            args.decoder_tokenizer_name
            if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
            do_lower_case=args.do_lower_case)
        model_decoder.to(args.device)
        if args.block_size <= 0:
            # Our input block size will be the max possible for the model
            args.block_size = tokenizer_decoder.max_len_single_sentence
        args.block_size = min(args.block_size,
                              tokenizer_decoder.max_len_single_sentence)
        print("block size: ", args.block_size)

        # Full training state (model_state_dict etc.) applied further below.
        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
    else:
        # Load GPT weights (As an alternaive, one may train a VAE for this small)
        decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
            args.decoder_model_type]
        decoder_config = decoder_config_class.from_pretrained(
            args.decoder_config_name
            if args.decoder_config_name else args.decoder_model_name_or_path)
        tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
            args.decoder_tokenizer_name
            if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
            do_lower_case=args.do_lower_case)
        if args.block_size <= 0:
            # Our input block size will be the max possible for the model
            args.block_size = tokenizer_decoder.max_len_single_sentence
        args.block_size = min(args.block_size,
                              tokenizer_decoder.max_len_single_sentence)

        # The latent vector must reach GPT-2 somehow (as memory, embedding, or both).
        if args.latent_as_gpt_emb + args.latent_as_gpt_memory == 0:
            return  # latent vector should pass into GPT to decode
        else:
            latent_as_gpt_emb = True if args.latent_as_gpt_emb == 1 else False
            latent_as_gpt_memory = True if args.latent_as_gpt_memory == 1 else False

        setattr(decoder_config, "latent_size", args.latent_size)
        model_decoder = decoder_model_class.from_pretrained(
            args.decoder_model_name_or_path,
            from_tf=bool('.ckpt' in args.decoder_model_name_or_path),
            config=decoder_config,
            latent_size=args.latent_size,
            latent_as_gpt_emb=latent_as_gpt_emb,
            latent_as_gpt_memory=latent_as_gpt_memory)

    # Save the init weights of GPT-2, so that we can load from local (Some infra requires so)
    if args.save_bert_gpt_init:
        decoder_path = os.path.join(
            args.output_dir,
            f"initial-models-tokenization-decoder-{args.latent_size}")
        if not os.path.exists(decoder_path):
            os.makedirs(decoder_path)
        model_decoder.save_pretrained(decoder_path)
        tokenizer_decoder.save_pretrained(decoder_path)
        return

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(
        len(tokenizer_decoder)
    )  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    model_vae = AE(model_decoder, tokenizer_decoder, args)

    if args.use_random_weight:
        print("---" * 20)
        print("random weight init")
        print("---" * 20)
        model_vae.apply(weights_init_rondom)

    if args.use_pretrained_model:
        # Only copy weights whose keys exist in the current model (tolerates
        # architectural differences between checkpoint and model).
        pre_model = checkpoint['model_state_dict']
        model_dict = model_vae.state_dict()
        pre_dict = {k: v for k, v in pre_model.items() if k in model_dict}
        model_dict.update(pre_dict)
        model_vae.load_state_dict(model_dict)
        logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    ##############################
    # Training
    global_step = 0
    if args.do_train:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier(
            )  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        print("===" * 20)
        train_dataloader = build_dataload_and_cache_examples(
            args, [tokenizer_decoder], evaluate=False)
        print("===" * 20)

        if args.local_rank == 0:
            torch.distributed.barrier()

        global_step, tr_loss, optimizer = train(args, train_dataloader,
                                                model_vae, tokenizer_decoder,
                                                table_name)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        save_checkpoint(model_vae, optimizer, global_step, args)

    ##############################
    # Evaluation the metrics of VAE models, including PPL, MI, AU
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        if global_step == 0:
            global_step = args.gloabl_step_eval

        output_decoder_dir = os.path.join(
            args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir = os.path.join(
            args.output_dir, 'checkpoint-full-{}'.format(global_step))
        checkpoint_dir = [output_decoder_dir, output_full_dir]
        logger.info("Evaluate the following checkpoint: %s",
                    checkpoint_dir[-1])
        global_step = checkpoint_dir[-1].split(
            '-')[-1] if len(checkpoint_dir) > 1 else ""

        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
        model_vae.load_state_dict(checkpoint['model_state_dict'])
        logger.info(
            f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
        model_vae.to(args.device)

        result = evaluate(args,
                          model_vae,
                          tokenizer_decoder,
                          table_name,
                          prefix=global_step,
                          subset='test')
        result = dict(
            (k + '_{}'.format(global_step), v) for k, v in result.items())
        results.update(result)

        output_eval_file = os.path.join(args.output_dir,
                                        "eval_vae_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("%s = %s", key, str(results[key]))
                writer.write("%s = %s\n" % (key, str(results[key])))
        logger.info(
            f"The testing results are successfully saved: {output_eval_file}")

    ##############################
    # Evaluate the reconstruction loss for each checkpoints;
    # This is used in studying two different latent vector injection schemes
    results = {}
    if args.do_eval_rec and args.local_rank in [-1, 0]:
        if global_step == 0:
            global_step = args.gloabl_step_eval

        # eval_steps = range(500, 13500, 500)
        # eval_steps = range(1000, 2000, 500)
        eval_steps = range(2000, 32000, 2000)
        checkpoints = []
        for e in eval_steps:
            output_decoder_dir = os.path.join(
                args.output_dir, 'checkpoint-decoder-{}'.format(e))
            checkpoints.append([output_decoder_dir])
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            global_step = checkpoint[0].split(
                '-')[-1] if len(checkpoints) > 1 else ""
            # BUGFIX: each entry is [output_decoder_dir] (one element), so the
            # original checkpoint[1] raised IndexError on every iteration.
            model_decoder = decoder_model_class.from_pretrained(checkpoint[0])
            model_decoder.to(args.device)

            model_vae = VAE(model_decoder, tokenizer_decoder,
                            args).to(args.device)

            result = evaluate_rec(args,
                                  model_vae,
                                  tokenizer_decoder,
                                  table_name,
                                  prefix=global_step,
                                  subset='test')
            result = dict((k + '_test_{}'.format(global_step), v)
                          for k, v in result.items())
            results.update(result)

            result = evaluate_rec(args,
                                  model_vae,
                                  tokenizer_decoder,
                                  table_name,
                                  prefix=global_step,
                                  subset='train')
            result = dict((k + '_train_{}'.format(global_step), v)
                          for k, v in result.items())
            results.update(result)

        output_eval_file = os.path.join(args.output_dir,
                                        "eval_rec_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("%s = %s", key, str(results[key]))
                writer.write("%s = %s\n" % (key, str(results[key])))
        logger.info(
            f"The testing results are successfully saved: {output_eval_file}")

    return results
def main(args):
    """Train a VAE (optionally with an adversarial reconstruction critic).

    Mirrors the sibling trainer in this file but reports losses through
    ``train_util.test`` and dumps a full traceback when a training epoch
    raises ``RuntimeError``.

    Fix vs. original: the crash path called ``exit(0)``, which reports
    *success* to the calling shell/scheduler even though training aborted.
    It now calls ``sys.exit(1)``.
    """
    input_dim = 3  # RGB input
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, "min", patience=args.lr_patience, factor=0.5,
    #             threshold=args.threshold, threshold_mode="abs", min_lr=1e-6)
    # ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, "min", patience=args.lr_patience, factor=0.5,
    #                 threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

    # Optional adversarial reconstruction critics, keyed by name.
    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    print("model built", file=sys.stderr)
    #print("model created")
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    print("loaders acquired", file=sys.stderr)
    #print("loaders acquired")

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}"
        )

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        os.makedirs(log_save_path)
    print(f"log:{log_save_path}", file=sys.stderr)
    sys.stderr.flush()
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    writer = SummaryWriter(log_save_path)
    print(f"train loader length:{len(train_loader)}", file=sys.stderr)

    # save_state compares against this initial "worst" loss.
    best_loss = torch.tensor(np.inf)
    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    # Fixed batch of test images used for reconstruction/interp grids.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    stop_patience = args.stop_patience  # NOTE(review): read but never used below
    for epoch in range(start_epoch, args.num_epochs):
        try:
            train(model, train_loader, opt, epoch, writer, args,
                  discriminators)
        except RuntimeError as err:
            # Dump the full traceback before aborting (e.g. CUDA OOM).
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            # BUGFIX: was exit(0) — signal failure, not success.
            sys.exit(1)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader,
                                           args, discriminators)
        print(f"epoch loss:{val_loss_dict['recons_loss'].item()}")

        train_util.save_recons_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, opt, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, model_save_path)
def main(args):
    """Train a VAE on MNIST / Fashion-MNIST / CIFAR-10 / MiniImagenet.

    Builds the dataset-specific loaders, logs a fixed batch of originals and
    reconstructions to TensorBoard, then trains for ``args.num_epochs``,
    saving the best model (by validation loss) and a per-epoch snapshot.
    """
    writer = SummaryWriter('./logs_vae/{0}'.format(args.output_folder))
    save_filename = './models_vae/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # Default transform (MNIST statistics); overridden per-dataset below.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # NOTE(review): 3-channel normalization on a 1-channel dataset —
            # presumably relies on broadcasting; confirm against transforms.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        # These datasets ship no separate validation split; validate on test.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    # NOTE(review): any other args.dataset value leaves num_channels /
    # train_dataset undefined and raises NameError below — verify callers
    # restrict the choices.

    # Define the data loaders
    # NOTE(review): shuffle=False on the *training* loader looks intentional
    # here but is unusual — confirm before reuse.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VAE(num_channels, args.hidden_size, args.z).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    # Generate the samples first once (epoch-0 baseline reconstruction)
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    # Sentinel; replaced by the real loss on epoch 0 via the (epoch == 0) branch.
    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(epoch, train_loader, model, optimizer, args, writer)
        loss = test(valid_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        # Keep the best checkpoint; also snapshot every epoch.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
def main(args, args_model):
    """Evaluate VAE latent codes with a logistic-regression classifier.

    Loads a trained text VAE (per ``args_model``), encodes the labelled
    train/test sets into latent codes, then for each low-resource training
    split runs a cross-validated ``LogisticRegression`` grid search scored by
    macro-F1 and log-loss, writing all results to a JSON file in
    ``args.exp_dir``.

    Fix vs. original: ``n_per_class`` was assigned only inside the
    ``classify_using_samples`` branch but referenced in the ``else`` branch,
    raising ``NameError`` when classifying on posterior means.  It is now
    defined before the branch.
    """
    global logging
    logging = get_logger_existing_dir(os.path.dirname(args.load_path),
                                      'log_classifier.txt')
    if args.cuda:
        logging('using cuda')
    logging(str(args))

    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # Optional fixed vocabulary file (word per line -> index).
    vocab = {}
    if getattr(args, 'vocab_file', None) is not None:
        with open(args.vocab_file) as fvocab:
            for i, line in enumerate(fvocab):
                vocab[line.strip()] = i
        vocab = VocabEntry(vocab)

    # One training file per random seed at the requested labels-per-class.
    filename_glob = args.train_data + '.seed_*.n_' + str(
        args.num_label_per_class)
    train_sets = glob.glob(filename_glob)
    print("Train sets:", train_sets)

    main_train_data = MonoTextData(args.train_data,
                                   label=args.label,
                                   vocab=vocab)
    vocab = main_train_data.vocab
    vocab_size = len(vocab)
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    #sys.stdout.flush()

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args_model.device = device

    # Build the encoder specified by the saved model's config.
    if args_model.enc_type == 'lstm':
        args_model.pooling = getattr(args_model, 'pooling', None)
        encoder = GaussianLSTMEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            pooling=args_model.pooling,
        )
    elif args_model.enc_type in ['max_avg_pool', 'max_pool', 'avg_pool']:
        args_model.skip_first_word = getattr(args_model, 'skip_first_word',
                                             None)
        encoder = GaussianPoolEncoder(
            args_model,
            vocab_size,
            model_init,
            emb_init,
            enc_type=args_model.enc_type,
            skip_first_word=args_model.skip_first_word)
        #args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    args_model.encode_length = getattr(args_model, 'encode_length', None)
    if args_model.dec_type == 'lstm':
        decoder = LSTMDecoder(args_model, vocab, model_init, emb_init,
                              args_model.encode_length)
    elif args_model.dec_type == 'unigram':
        decoder = UnigramDecoder(args_model, vocab, model_init, emb_init)

    vae = VAE(encoder, decoder, args_model,
              args_model.encode_length).to(device)

    if args.load_path:
        print("load args!")
        print(vae)
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)
    vae.eval()

    def preprocess(data_fn):
        # Encode a labelled text file into (codes, labels) numpy arrays.
        codes, labels = read_dataset(data_fn, vocab, device, vae,
                                     args.classify_using_samples)
        if args.classify_using_samples:
            is_gaussian_enc = codes.shape[1] == (vae.encoder.nz * 2)
            codes = augment_dataset(codes, 1, is_gaussian_enc,
                                    vae)  # use only 1 sample for test
        codes = codes.cpu().numpy()
        labels = labels.cpu().numpy()
        return codes, labels

    test_codes, test_labels = preprocess(args.test_data)
    test_f1_scores = []
    average_f1 = 'macro'
    f1_scorer = make_scorer(f1_score,
                            average=average_f1,
                            labels=np.unique(test_labels),
                            greater_is_better=True)
    # log loss: negative log likelihood. We should minimize that, so greater_is_better=False
    log_loss_scorer = make_scorer(log_loss,
                                  needs_proba=True,
                                  greater_is_better=False)
    warnings.filterwarnings('ignore')

    results = {
        'n_samples_per_class': args.num_label_per_class,
    }
    n_repeats = args.n_repeats
    n_splits = min(args.num_label_per_class, 5)
    for i, fn in enumerate(train_sets):
        codes, labels = preprocess(fn)
        if args.resample > 1:
            # going to augment the training set by sampling
            # then create a new cross validation function to get the correct indices
            cross_val = augment_cross_val(labels, args.resample, n_splits,
                                          n_repeats)
            labels = np.repeat(labels, args.resample)
        else:
            cross_val = RepeatedStratifiedKFold(n_splits=n_splits,
                                                n_repeats=n_repeats)

        # Standardize using the training split, apply to the test set.
        scaler = StandardScaler()
        codes = scaler.fit_transform(codes)
        scaled_test_codes = scaler.transform(test_codes)

        gridsearch = GridSearchCV(
            LogisticRegression(solver='sag', multi_class='auto'),
            {
                "penalty": ['l2'],
                "C": [0.01, 0.1, 1, 10, 100],
            },
            cv=cross_val,
            scoring={
                "f1": f1_scorer,
                "log": log_loss_scorer,
            },
            refit=False,  # refit manually per metric via refit_and_eval
        )
        clf = gridsearch
        clf.fit(codes, labels)
        crossval_f1, test_f1 = refit_and_eval(
            'f1',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            f1_scorer,
        )
        crossval_log, test_log_loss = refit_and_eval(
            'log',
            clf,
            clf.cv_results_,
            codes,
            labels,
            scaled_test_codes,
            test_labels,
            log_loss_scorer,
        )
        results[i] = {
            "F1": {
                'crossval': crossval_f1,
                'test': test_f1
            },
            "log": {
                'crossval': crossval_log,
                'test': test_log_loss
            },
        }
        print(results[i])

    # BUGFIX: n_per_class was only defined inside the first branch, so the
    # else-branch raised NameError.
    n_per_class = str(args.num_label_per_class)
    if args.classify_using_samples:
        resample = 1 if args.resample == -1 else args.resample
        output_fn = os.path.join(
            args.exp_dir,
            'results_sample_' + str(resample) + '_' + n_per_class + '.json')
    else:
        output_fn = os.path.join(args.exp_dir,
                                 'results_' + n_per_class + '.json')
    with open(output_fn, 'w') as f:
        json.dump(results, f)
def main():
    """Entry point for training/evaluating an Optimus-style text VAE.

    Wires together a pretrained Transformer encoder (default BERT) and a
    pretrained Transformer decoder (default GPT2) into a VAE, sets up
    (optionally distributed) CUDA execution, and runs training when
    --do_train is given.  All configuration comes from the command line.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--input_file_path", default=None, type=str, required=True,
                        help="The output directory where the input files will be written.")
    parser.add_argument("--output_file_path", default=None, type=str, required=True,
                        help="The output directory where the output files will be written.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the logs and results will be saved.")
    parser.add_argument("--dataset", default=None, type=str, help="The dataset.")

    ## Other parameters
    parser.add_argument("--ExpName", default="", type=str,
                        help="The experiment name used in Azure Table.")

    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    # NOTE(review): default checkpoint is "bert-base-cased" even though the
    # default decoder type is gpt2 — callers probably always override this;
    # confirm before relying on the default.
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--use_deterministic_connect", action='store_true',
                        help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")

    ## Objective functions
    parser.add_argument("--mlm", action='store_true',
                        help="Train with masked-language modeling loss instead of language modeling.")
    parser.add_argument("--mlm_probability", type=float, default=0.15,
                        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--beta", type=float, default=1.0,
                        help="The weighting hyper-parameter of the KL term in VAE")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                        "The training dataset will be truncated in block of this size for training."
                        "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    # Training Schedules
    parser.add_argument("--ratio_increase", default=0.25, type=float,
                        help="Learning schedule, the percentage for the annealing stage.")
    parser.add_argument("--ratio_zero", default=0.25, type=float,
                        help="Learning schedule, the percentage for the pure auto-encoding stage.")
    parser.add_argument("--fb_mode", default=0, type=int, help="free bit training mode.")
    parser.add_argument("--dim_target_kl", default=3.0, type=float,
                        help="dim_target_kl free bit training mode.")
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")

    ## IO: Logging and Saving
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    # NOTE(review): flag name is misspelled ("gloabl"); renaming would break
    # existing invocations, so it is documented here instead of fixed.
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="For distant debugging.")
    args = parser.parse_args()

    # BERT-style models only ship masked-LM heads, so plain causal LM decoding
    # is rejected up front.
    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError(
            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
            "flag (masked language modeling).")

    if os.path.exists(args.output_file_path) and os.listdir(
            args.output_file_path
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_file_path))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    logger.info(f'Local rank is {args.local_rank}')
    if args.local_rank == -1 or args.no_cuda:
        # Single-process mode: use every visible GPU (or the CPU).
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging (only rank 0 / single-process logs at INFO level)
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Derive a unique experiment / Azure-table name from the key hyper-parameters.
    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(
        args.latent_size) + '_Beta_' + str(args.beta) + '_Dkl_' + str(
            args.dim_target_kl) + '_Ra_' + str(
                args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
    try:
        ts.create_table(table_name)
    except:
        # NOTE(review): bare except deliberately ignores "table already exists"
        # but will also hide real failures (auth, network) — consider narrowing.
        pass

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    ## Encoder
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[
        args.encoder_model_type]
    encoder_config = encoder_config_class.from_pretrained(
        args.encoder_config_name if args.encoder_config_name else args.
        encoder_model_name_or_path)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name
        if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_encoder.max_len_single_sentence)
    model_encoder = encoder_model_class.from_pretrained(
        args.encoder_model_name_or_path,
        from_tf=bool('.ckpt' in args.encoder_model_name_or_path),
        config=encoder_config,
        latent_size=args.latent_size)
    # model_encoder.to(args.device)

    ## Decoder
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[
        args.decoder_model_type]
    decoder_config = decoder_config_class.from_pretrained(
        args.decoder_config_name if args.decoder_config_name else args.
        decoder_model_name_or_path)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name
        if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    # block_size is clamped again against the decoder tokenizer, so the final
    # value respects both models' single-sentence limits.
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size,
                          tokenizer_decoder.max_len_single_sentence)
    model_decoder = decoder_model_class.from_pretrained(
        args.decoder_model_name_or_path,
        from_tf=bool('.ckpt' in args.decoder_model_name_or_path),
        config=decoder_config,
        latent_size=args.latent_size)

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {
        'pad_token': '<PAD>',
        'bos_token': '<BOS>',
        'eos_token': '<EOS>'
    }
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(
        len(tokenizer_decoder)
    )  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    # model_decoder.to(args.device)

    # Wrap the two pretrained models into the VAE and move it to the target device.
    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder,
                    tokenizer_decoder, args).to(args.device)  #

    # on_gpu = next(model_vae.parameters()).is_cuda

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    global_step = 0
    # Training
    if args.do_train:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier(
            )  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataloader = build_dataload_and_cache_examples(
            args, [tokenizer_encoder, tokenizer_decoder], evaluate=False)

        if args.local_rank == 0:
            torch.distributed.barrier()

        num_collected, num_dropped = train(args, train_dataloader, model_vae,
                                           tokenizer_encoder,
                                           tokenizer_decoder, table_name)
        logger.info(" num_collected = %s, num_dropped = %s", num_collected,
                    num_dropped)
def main(args):
    """Load a trained text VAE and dump latent vectors for all three splits.

    Reads the training corpus to build the vocabulary, writes the vocabulary
    to ``vocab.txt`` beside the training file, restores the VAE weights from
    ``args.load_path``, and calls ``save_latents`` for train/val/test.
    """
    # Build vocabulary from the training corpus.
    corpus = MonoTextData(args.train_data, label=args.label)
    vocab = corpus.vocab
    n_words = len(vocab)

    # Persist the vocabulary next to the training file for downstream reuse.
    vocab_path = os.path.join("/".join(args.train_data.split("/")[:-1]),
                              "vocab.txt")
    with open(vocab_path, "w") as fout:
        for idx in range(n_words):
            fout.write("{}\n".format(vocab.id2word(idx)))
    #return

    # Validation and test sets are indexed with the training vocabulary.
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(corpus))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % corpus.dropped)
    sys.stdout.flush()

    log_niter = (len(corpus) // args.batch_size) // 10  # unused; kept for parity

    weight_init = uniform_initializer(0.01)
    embed_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    # Only the LSTM encoder is supported by this script; reject anything else.
    if args.enc_type != 'lstm':
        raise ValueError("the specified encoder type is not supported")
    encoder = GaussianLSTMEncoder(args, n_words, weight_init, embed_init)
    args.enc_nh = args.dec_nh

    decoder = LSTMDecoder(args, vocab, weight_init, embed_init)
    vae = VAE(encoder, decoder, args).to(device)

    print('begin evaluation')
    vae.load_state_dict(torch.load(args.load_path))
    vae.eval()
    with torch.no_grad():
        test_data_batch, test_batch_labels = test_data.create_data_batch_labels(
            batch_size=args.batch_size, device=device, batch_first=True)
        train_data_batch, train_batch_labels = corpus.create_data_batch_labels(
            batch_size=args.batch_size, device=device, batch_first=True)
        val_data_batch, val_batch_labels = val_data.create_data_batch_labels(
            batch_size=args.batch_size, device=device, batch_first=True)

        print("getting vectors for training")
        print(args.save_dir)
        save_latents(args, vae, train_data_batch, train_batch_labels, "train")
        print("getting vectors for validating")
        save_latents(args, vae, val_data_batch, val_batch_labels, "val")
        print("getting vectors for testing")
        save_latents(args, vae, test_data_batch, test_batch_labels, "test")
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True) print('train_loader', len(train_loader)) print('test_loader', len(test_loader)) #%% model = VAE().to(device) # optimizer = optim.Adam(model.parameters(), lr=1e-3) optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0005) #%% epochs = 100 viz = Visdom() global plotter, recon plotter = utils.VisdomLinePlotter(env_name='main') sample_image = utils.VisdomImage(env_name='main') recon = utils.VisdomImage(env_name='main') for epoch in range(1, epochs + 1):
def main(args):
    """Train (or sample from / evaluate) an image VAE with aggressive encoder updates.

    Modes, selected by args:
      * args.sample_from != '' : load a checkpoint, decode 400 prior samples
        to PNG grids, and return.
      * args.eval              : load a checkpoint, run test metrics (active
        units, importance-weighted NLL), and return.
      * otherwise              : train with KL annealing and the "aggressive"
        inner encoder-optimization loop (lagging-inference-style), with
        LR decay driven by validation loss.

    NOTE(review): `clip_grad`, `decay_epoch`, `lr_decay`, and `max_decay` are
    read but not defined in this function — presumably module-level constants;
    confirm they exist at module scope.
    """
    if args.save_path == '':
        make_savepath(args)
        seed(args)

    if args.cuda:
        print('using cuda')

    print(args)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device

    # LR-decay bookkeeping: epochs since improvement, current LR, best val loss.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # Expects a 3-tuple of tensors (train, val, test) — TODO confirm the
    # layout of args.data_file against the preprocessing script.
    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data

    x_train = x_train.to(device)
    x_val = x_val.to(device)
    x_test = x_test.to(device)

    # Dummy zero labels: TensorDataset requires a label tensor, but labels
    # are unused in unsupervised training.
    y_size = 1
    y_train = x_train.new_zeros(x_train.size(0), y_size)
    y_val = x_train.new_zeros(x_val.size(0), y_size)
    y_test = x_train.new_zeros(x_test.size(0), y_size)
    print(torch.__version__)
    train_data = torch.utils.data.TensorDataset(x_train, y_train)
    val_data = torch.utils.data.TensorDataset(x_val, y_val)
    test_data = torch.utils.data.TensorDataset(x_test, y_test)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    sys.stdout.flush()

    # Log training stats 5 times per epoch.
    log_niter = len(train_loader) // 5

    encoder = ResNetEncoderV2(args)
    decoder = PixelCNNDecoderV2(args)

    vae = VAE(encoder, decoder, args).to(device)

    # Sampling-only mode: decode prior samples to image grids and exit.
    if args.sample_from != '':
        save_dir = "samples/%s" % args.dataset
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        vae.load_state_dict(torch.load(args.sample_from))
        vae.eval()
        with torch.no_grad():
            sample_z = vae.sample_from_prior(400).to(device)
            sample_x, sample_probs = vae.decode(sample_z, False)
            image_file = 'sample_binary_from_%s.png' % (
                args.sample_from.split('/')[-1][:-3])
            save_image(sample_x.data.cpu(),
                       os.path.join(save_dir, image_file),
                       nrow=20)
            image_file = 'sample_cont_from_%s.png' % (
                args.sample_from.split('/')[-1][:-3])
            save_image(sample_probs.data.cpu(),
                       os.path.join(save_dir, image_file),
                       nrow=20)
        return

    # Evaluation-only mode: metrics on the test set, then exit.
    if args.eval:
        print('begin evaluation')
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  batch_size=50,
                                                  shuffle=True)
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test(vae, test_loader, "TEST", args)
            au, au_var = calc_au(vae, test_loader)
            print("%d active units" % au)
            # print(au_var)
            calc_iwnll(vae, test_loader, args)
        return

    # Separate optimizers so the encoder can be stepped alone during the
    # aggressive (burn-in) phase.
    enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001)
    dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001)
    opt_dict['lr'] = 0.001

    iter_ = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    decay_cnt = pre_mi = best_mi = mi_not_improved = 0
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    # Linear KL annealing from args.kl_start up to 1.0 over warm_up epochs.
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up * len(train_loader))

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_examples = 0
        for datum in train_loader:
            batch_data, _ = datum
            # Dynamic binarization: resample binary pixels every epoch.
            batch_data = torch.bernoulli(batch_data)
            batch_size = batch_data.size(0)

            report_num_examples += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            # Aggressive inner loop: keep optimizing the encoder on fresh
            # random batches until its loss stops improving (checked every
            # 10 sub-iterations), capped at 100 sub-iterations.
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_examples = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_num_examples += batch_data_enc.size(0)
                loss, loss_rc, loss_kl = vae.loss(batch_data_enc,
                                                  kl_weight,
                                                  nsamples=args.nsamples)
                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                # Only the encoder is stepped during burn-in.
                enc_optimizer.step()

                id_ = np.random.choice(x_train.size(0),
                                       args.batch_size,
                                       replace=False)

                batch_data_enc = torch.bernoulli(x_train[id_])

                if sub_iter % 10 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_examples
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_examples = 0

                sub_iter += 1

            # print(sub_iter)

            # Regular joint step on the original batch.
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data,
                                              kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            # While aggressive, the encoder was already updated in the inner
            # loop; only the decoder is stepped here.
            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_examples
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_loader)
                        au, _ = calc_au(vae, val_loader)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples, mi,
                           report_rec_loss / report_num_examples, au, time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss, report_kl_loss / report_num_examples,
                           report_rec_loss / report_num_examples, time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_examples = 0

            iter_ += 1

            # Once per epoch's worth of iterations: stop the aggressive phase
            # after mutual information fails to improve 5 times.
            if aggressive_flag and (iter_ % len(train_loader)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_loader)
                vae.train()
                if cur_mi - best_mi < 0:
                    mi_not_improved += 1
                    if mi_not_improved == 5:
                        aggressive_flag = False
                        print("STOP BURNING")
                else:
                    best_mi = cur_mi

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)
        print('epoch: %d, VAL' % epoch)

        vae.eval()
        with torch.no_grad():
            loss, nll, kl = test(vae, val_loader, "VAL", args)
            au, au_var = calc_au(vae, val_loader)
            print("%d active units" % au)
            # print(au_var)

        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            torch.save(vae.state_dict(), args.save_path)

        # No improvement: after `decay_epoch` stale epochs, reload the best
        # checkpoint and decay the learning rate.
        if loss > best_loss:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                decay_cnt += 1
                print('new lr: %f' % opt_dict["lr"])
                enc_optimizer = optim.Adam(vae.encoder.parameters(),
                                           lr=opt_dict["lr"])
                dec_optimizer = optim.Adam(vae.decoder.parameters(),
                                           lr=opt_dict["lr"])
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl = test(vae, test_loader, "TEST", args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl = test(vae, test_loader, "TEST", args)
        au, au_var = calc_au(vae, test_loader)
        print("%d active units" % au)
        # print(au_var)

    # Smaller fixed batch size for the (memory-hungry) IW-NLL estimate.
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=50,
                                              shuffle=True)
    with torch.no_grad():
        calc_iwnll(vae, test_loader, args)
type=str, default="mnist", metavar='S', help= 'Training mode selection. Choices: mnist, synthetic_timeseries, cell_timeseries. (default: mnist)' ) args = parser.parse_args() model_filepath = "model-{}.pth".format(args.train_mode) root_path = "results/{}".format(args.train_mode) try: os.makedirs(root_path) except: pass is_cuda = not args.no_cuda device = torch.device("cuda" if is_cuda else "cpu") model = VAE(dropout=args.dropout, input_dim=input_dims[args.train_mode]).to(device) try: model = load_model(model_filepath, model) logger.info("Loading model from {}".format(model_filepath)) except: logger.info("Creating VAE model from scratch") model = VAE(dropout=args.dropout, input_dim=input_dims[args.train_mode]).to(device) if args.train_mode == 'mnist': train_mnist(model, device, args.epochs, root_path) elif args.train_mode == "synthetic_timeseries": model.decoder.sigmoid = False # disable sigmoid from the final decoder layer train_synthetic_timeseries(model, device, args.epochs, root_path) elif args.train_mode == "cell_timeseries": model.decoder.sigmoid = False # disable sigmoid from the final decoder layer train_cell_timeseries(model, device, args.epochs, root_path)
class VAESampler:
    """Thin wrapper around a trained text VAE for encoding strings to latent
    vectors and decoding latent vectors back to text.

    Subclasses must implement :meth:`str2ids` before :meth:`encode`,
    :meth:`z`, or :meth:`mu` can be used.
    """

    def __init__(self, decode_from, params, cuda=False):
        """Rebuild the VAE architecture from `params` and load weights.

        Args:
            decode_from: path to a saved ``state_dict`` checkpoint.
            params: namespace of model hyper-parameters; mutated here
                (``enc_nh`` and ``device`` are set on it).
            cuda: load/run on GPU when True, otherwise CPU.
        """
        self.decode_from = decode_from
        self.params = params

        params.enc_nh = params.dec_nh  # not sure why this is necessary...

        # Training corpus is re-read only to recover the vocabulary.
        self.train_data = MonoTextData(params.train_data, label=False)
        self.vocab = self.train_data.vocab
        self.vocab_size = len(self.vocab)

        model_init = uniform_initializer(0.01)
        emb_init = uniform_initializer(0.1)

        params.device = self.device = torch.device("cuda" if cuda else "cpu")

        self.encoder = LSTMEncoder(params, self.vocab_size, model_init,
                                   emb_init)
        self.decoder = LSTMDecoder(params, self.vocab, model_init, emb_init)
        self.vae = VAE(self.encoder, self.decoder, params).to(params.device)

        # assuming models were trained on a gpu...
        if cuda:
            self.vae.load_state_dict(torch.load(self.decode_from))
        else:
            self.vae.load_state_dict(
                torch.load(self.decode_from, map_location='cpu'))

    def to_s(self, decoded):
        """Join each decoded token list into a single space-separated string."""
        return [' '.join(item) for item in decoded]

    def beam(self, z, K=5):
        """Beam-search decode latent batch `z` with beam width `K`."""
        decoded_batch = self.vae.decoder.beam_search_decode(z, K)
        return self.to_s(decoded_batch)

    def sample(self, z, temperature=1.0):
        """Stochastically decode latent batch `z` at the given temperature."""
        decoded_batch = self.vae.decoder.sample_decode(z, temperature)
        return self.to_s(decoded_batch)

    def greedy(self, z):
        """Greedily decode latent batch `z` (argmax at each step)."""
        decoded_batch = self.vae.decoder.greedy_decode(z)
        return self.to_s(decoded_batch)

    def str2ids(self, s):
        """Encode string `s` as a list of word ids.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # BUG FIX: was `raise NotImplemented` — NotImplemented is a sentinel
        # value, not an exception, so raising it produced a TypeError.
        raise NotImplementedError

    def encode(self, t):
        """
        Returns (z, mu, log_var) from encoder given list of strings.
        z is a sample from gaussian specified with (mu, log_var)
        """
        str_ids = [self.str2ids(s) for s in t]
        tensor = self.train_data._to_tensor(str_ids, True, self.device)[0]
        z, (mu, log_var) = self.vae.encoder.sample(tensor, 1)
        return z, mu, log_var

    def z(self, t):
        "return sampled latent zs for list of strings t"
        z, _, _ = self.encode(t)
        return z.squeeze(1)

    def mu(self, t):
        "return mean of latent gaussian for list of strings t"
        _, mu, _ = self.encode(t)
        return mu.squeeze(1)
def main(args):
    """Train a discriminator (probe) on top of a frozen pretrained text VAE encoder.

    Loads a vocabulary file and the three data splits, reconstructs the VAE
    and restores its weights, then trains a linear or MLP discriminator with
    gradient accumulation (`args.update_every`), LR decay on validation-loss
    plateaus, and best-checkpoint saving/reloading.

    NOTE(review): `clip_grad`, `decay_epoch`, `lr_decay`, and `max_decay` are
    read but not defined here — presumably module-level constants; confirm.
    """
    global logging
    # `logging` is rebound to the experiment logger function for the whole module.
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # LR-decay bookkeeping: epochs since improvement, current LR, best val loss.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # One token per line in the vocab file; line number becomes the word id.
    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    # Log ~10 times per (effective) epoch; never less often than every batch.
    log_niter = max(1, (len(train_data) //
                        (args.batch_size * args.update_every)) // 10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    # fb == 3 selects the delta-posterior encoder regardless of enc_type.
    if args.fb == 3:
        encoder = DeltaGaussianLSTMEncoder(args, vocab_size, model_init,
                                           emb_init)
        args.enc_nh = args.dec_nh
    elif args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # The discriminator reads the (frozen) VAE encoder's representation.
    # NOTE(review): `discriminator` stays unbound if args.discriminator is
    # neither "linear" nor "mlp" — the next use would raise NameError.
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    discriminator.train()
    start = time.time()

    # Batches are materialized once up-front; validation/test use a fixed
    # batch size of 128 independent of args.batch_size.
    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(
        batch_size=args.batch_size, device=device, batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(
        batch_size=128, device=device, batch_first=True)

    # Gradient accumulation state: step the optimizer every
    # args.update_every batches.
    acc_cnt = 1
    acc_loss = 0.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()

        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            # Skip degenerate batches with fewer than two sentences.
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels,
                                        dtype=torch.long,
                                        requires_grad=False,
                                        device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(
                batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                # Average the accumulated loss over all sentences it covers
                # before backprop.
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                               clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                train_loss = report_loss / report_num_sents

                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f,' \
                        'time %.2fs' %
                        (epoch, iter_, train_loss,
                         report_correct / report_num_sents, time.time() - start))

                #sys.stdout.flush()

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))
        print(report_num_sents)

        discriminator.eval()

        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch,
                             "VAL", args)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        # No improvement over the tracked best: after `decay_epoch` stale
        # epochs (and past args.load_best_epoch), reload the best checkpoint
        # and decay the learning rate, rebuilding the optimizer.
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict[
                    "not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(),
                                           lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch,
                                 test_labels_batch, "TEST", args)

        discriminator.train()

    # Final evaluation with the best checkpoint restored.
    discriminator.load_state_dict(torch.load(args.save_path))

    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch,
                         "TEST", args)
def main(args):
    """Train an LSTM VAE on text, or evaluate a saved model when ``args.eval`` is set.

    Training uses KL-annealing from ``args.kl_start`` and (optionally) the
    "aggressive" encoder-burning schedule: while ``args.aggressive`` is on, the
    encoder is updated repeatedly per batch until the burn loss stops improving,
    and the flag is dropped once validation mutual information stops rising.
    The best model (by validation loss) is checkpointed to ``args.save_path``;
    on plateau the checkpoint is reloaded and the SGD learning rate decayed.

    Args:
        args: parsed argparse namespace; must carry data paths, model sizes,
            ``cuda``, ``epochs``, ``batch_size``, ``kl_start``, ``warm_up``,
            ``aggressive``, ``nsamples``, ``momentum``, ``save_path``, etc.

    NOTE(review): relies on module-level names defined elsewhere in this file —
    ``clip_grad``, ``decay_epoch``, ``lr_decay``, ``max_decay`` and the helpers
    ``MonoTextData``, ``LSTMEncoder``, ``LSTMDecoder``, ``VAE``, ``test``,
    ``calc_mi``, ``calc_au``, ``calc_iwnll`` — confirm they are in scope.
    """
    class uniform_initializer(object):
        """Initialize a tensor uniformly in [-stdv, stdv]."""
        def __init__(self, stdv):
            self.stdv = stdv

        def __call__(self, tensor):
            nn.init.uniform_(tensor, -self.stdv, self.stdv)

    class xavier_normal_initializer(object):
        """Xavier-normal tensor initializer (kept for parity; not used below)."""
        def __call__(self, tensor):
            nn.init.xavier_normal_(tensor)

    if args.cuda:
        print('using cuda')
    print(args)

    # lr-decay bookkeeping: patience counter, current lr, best val loss seen.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    train_data = MonoTextData(args.train_data, label=args.label)
    vocab = train_data.vocab
    vocab_size = len(vocab)

    # Validation/test share the training vocabulary.
    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    print('Train data: %d samples' % len(train_data))
    print('finish reading datasets, vocab size is %d' % len(vocab))
    print('dropped sentences: %d' % train_data.dropped)
    sys.stdout.flush()

    # Log roughly 10 times per epoch.
    log_niter = (len(train_data) // args.batch_size) // 10

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    if args.enc_type == 'lstm':
        encoder = LSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    vae = VAE(encoder, decoder, args).to(device)

    if args.eval:
        # Evaluation-only path: load checkpoint, report test metrics,
        # active units, and the importance-weighted NLL estimate.
        print('begin evaluation')
        vae.load_state_dict(torch.load(args.load_path))
        vae.eval()
        with torch.no_grad():
            test_data_batch = test_data.create_data_batch(
                batch_size=args.batch_size, device=device, batch_first=True)
            test(vae, test_data_batch, "TEST", args)
            au, au_var = calc_au(vae, test_data_batch)
            print("%d active units" % au)
            # print(au_var)
            # IW-NLL needs batch_size=1.
            test_data_batch = test_data.create_data_batch(batch_size=1,
                                                          device=device,
                                                          batch_first=True)
            calc_iwnll(vae, test_data_batch, args)
        return

    # Separate optimizers so the aggressive schedule can step the encoder alone.
    enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=1.0,
                              momentum=args.momentum)
    dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=1.0,
                              momentum=args.momentum)
    opt_dict['lr'] = 1.0

    iter_ = decay_cnt = 0
    best_loss = 1e4
    best_kl = best_nll = best_ppl = 0
    pre_mi = 0
    aggressive_flag = True if args.aggressive else False
    vae.train()
    start = time.time()

    # Linear KL annealing: reach weight 1.0 after `warm_up` epochs worth of batches.
    kl_weight = args.kl_start
    anneal_rate = (1.0 - args.kl_start) / (args.warm_up *
                                           (len(train_data) / args.batch_size))

    train_data_batch = train_data.create_data_batch(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)
    val_data_batch = val_data.create_data_batch(batch_size=args.batch_size,
                                                device=device,
                                                batch_first=True)
    test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
                                                  device=device,
                                                  batch_first=True)

    for epoch in range(args.epochs):
        report_kl_loss = report_rec_loss = 0
        report_num_words = report_num_sents = 0

        for i in np.random.permutation(len(train_data_batch)):
            batch_data = train_data_batch[i]
            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size

            # kl_weight = 1.0
            kl_weight = min(1.0, kl_weight + anneal_rate)

            # Aggressive phase: repeatedly update the encoder on random batches
            # until the running burn loss stops improving (checked every 15
            # sub-iterations), capped at 100 sub-iterations.
            sub_iter = 1
            batch_data_enc = batch_data
            burn_num_words = 0
            burn_pre_loss = 1e4
            burn_cur_loss = 0
            while aggressive_flag and sub_iter < 100:
                enc_optimizer.zero_grad()
                dec_optimizer.zero_grad()

                burn_batch_size, burn_sents_len = batch_data_enc.size()
                burn_num_words += (burn_sents_len - 1) * burn_batch_size

                loss, loss_rc, loss_kl = vae.loss(batch_data_enc, kl_weight,
                                                  nsamples=args.nsamples)

                burn_cur_loss += loss.sum().item()
                loss = loss.mean(dim=-1)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

                # Only the encoder steps during the burn-in.
                enc_optimizer.step()

                # FIX: np.random.random_integers is deprecated (since NumPy
                # 1.11) and gone from modern NumPy; randint(0, n) draws from
                # the same inclusive range [0, n-1].
                id_ = np.random.randint(0, len(train_data_batch))
                batch_data_enc = train_data_batch[id_]

                if sub_iter % 15 == 0:
                    burn_cur_loss = burn_cur_loss / burn_num_words
                    if burn_pre_loss - burn_cur_loss < 0:
                        break
                    burn_pre_loss = burn_cur_loss
                    burn_cur_loss = burn_num_words = 0

                sub_iter += 1

            # Regular VAE step on the original batch.
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight,
                                              nsamples=args.nsamples)

            loss = loss.mean(dim=-1)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad)

            loss_rc = loss_rc.sum()
            loss_kl = loss_kl.sum()

            # During the aggressive phase the encoder was already stepped above.
            if not aggressive_flag:
                enc_optimizer.step()

            dec_optimizer.step()

            report_rec_loss += loss_rc.item()
            report_kl_loss += loss_kl.item()

            if iter_ % log_niter == 0:
                train_loss = (report_rec_loss +
                              report_kl_loss) / report_num_sents
                if aggressive_flag or epoch == 0:
                    vae.eval()
                    with torch.no_grad():
                        mi = calc_mi(vae, val_data_batch)
                        au, _ = calc_au(vae, val_data_batch)
                    vae.train()

                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f,' \
                          'au %d, time elapsed %.2fs' %
                          (epoch, iter_, train_loss,
                           report_kl_loss / report_num_sents, mi,
                           report_rec_loss / report_num_sents, au,
                           time.time() - start))
                else:
                    print('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \
                          'time elapsed %.2fs' %
                          (epoch, iter_, train_loss,
                           report_kl_loss / report_num_sents,
                           report_rec_loss / report_num_sents,
                           time.time() - start))

                sys.stdout.flush()

                report_rec_loss = report_kl_loss = 0
                report_num_words = report_num_sents = 0

            iter_ += 1

            # Once per epoch's worth of iterations: stop the aggressive phase
            # as soon as validation mutual information stops improving.
            if aggressive_flag and (iter_ % len(train_data_batch)) == 0:
                vae.eval()
                cur_mi = calc_mi(vae, val_data_batch)
                vae.train()
                print("pre mi:%.4f. cur mi:%.4f" % (pre_mi, cur_mi))
                if cur_mi - pre_mi < 0:
                    aggressive_flag = False
                    print("STOP BURNING")

                pre_mi = cur_mi

        print('kl weight %.4f' % kl_weight)

        # End-of-epoch validation.
        vae.eval()
        with torch.no_grad():
            loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args)
            au, au_var = calc_au(vae, val_data_batch)
            print("%d active units" % au)
            # print(au_var)

        # Checkpoint on validation improvement.
        if loss < best_loss:
            print('update best loss')
            best_loss = loss
            best_nll = nll
            best_kl = kl
            best_ppl = ppl
            torch.save(vae.state_dict(), args.save_path)

        # Plateau handling: after `decay_epoch` non-improving epochs (and at
        # least 15 epochs total), reload the best checkpoint and decay the lr.
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= 15:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                vae.load_state_dict(torch.load(args.save_path))
                print('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                enc_optimizer = optim.SGD(vae.encoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
                dec_optimizer = optim.SGD(vae.decoder.parameters(),
                                          lr=opt_dict["lr"],
                                          momentum=args.momentum)
        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST",
                                             args)

        vae.train()

    # compute importance weighted estimate of log p(x)
    vae.load_state_dict(torch.load(args.save_path))
    vae.eval()
    with torch.no_grad():
        loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args)
        au, au_var = calc_au(vae, test_data_batch)
        print("%d active units" % au)
        # print(au_var)

    test_data_batch = test_data.create_data_batch(batch_size=1,
                                                  device=device,
                                                  batch_first=True)
    with torch.no_grad():
        calc_iwnll(vae, test_data_batch, args)
def main():
    """Generate sentences from a trained text VAE prior and score them with BLEU.

    Loads fine-tuned encoder/decoder checkpoints (at --gloabl_step_eval [sic])
    from --checkpoint_dir, samples generations into --output_dir, then computes
    corpus BLEU-5 against --eval_data_file and Self-BLEU-5, writing both to an
    eval-results file.

    NOTE(review): depends on module-level names defined elsewhere in this file
    (MODEL_CLASSES, set_seed, logger, VAE, evaluate_generation_fromp_prior,
    Bleu, SelfBleu) — confirm they are in scope.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--checkpoint_dir", default=None, type=str, required=True,
                        help="The directory where checkpoints are saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default='Snli', type=str,
                        help="The dataset.")

    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int,
                        help="Latent space dimension.")
    parser.add_argument("--total_sents", default=10, type=int,
                        help="Total sentences to test recontruction.")
    parser.add_argument("--num_sents", default=10, type=int,
                        help="Total sentences to generate.")

    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    # NOTE(review): flag name has a typo ("gloabl") but is kept for CLI
    # backward compatibility — existing scripts pass it spelled this way.
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")

    ## Variational auto-encoder
    parser.add_argument("--nz", default=32, type=int,
                        help="Latent space dimension.")

    # Sampling parameters.
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--padding_text", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")

    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()

    set_seed(args)

    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()

    global_step = args.gloabl_step_eval
    output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary that you have fine-tuned
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_encoder.to(args.device)
    if args.block_size <= 0:
        # Our input block size will be the max possible for the model
        args.block_size = tokenizer_encoder.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary that you have fine-tuned
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        # Our input block size will be the max possible for the model
        args.block_size = tokenizer_decoder.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

    # Chunyuan: Add Padding token to GPT2
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    # Notice: resize_token_embeddings expect to receive the full size of the
    # new vocabulary, i.e. the length of the tokenizer.
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))
    assert tokenizer_decoder.pad_token == '<PAD>'

    # Evaluation
    model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    args.output_generation_file = os.path.join(
        args.output_dir, f"generation_from_vae_prior_t{args.temperature}_p{args.top_p}.txt")
    # args.output_generation_file = args.train_data_file
    result = evaluate_generation_fromp_prior(model_vae, tokenizer_decoder, args)

    bleu5 = Bleu(test_text=args.output_generation_file,
                 real_text=args.eval_data_file,
                 num_real_sentences=args.num_sents,
                 num_fake_sentences=args.num_sents,
                 gram=5).get_score()
    logger.info(f'The bleu score is {bleu5}')

    sbleu5 = SelfBleu(test_text=args.output_generation_file,
                      num_sentences=args.num_sents,
                      gram=5).get_score()
    logger.info(f'The self-bleu score is {sbleu5}')

    args.eval_results_file = os.path.join(
        args.output_dir, f"eval_results_t{args.temperature}_p{args.top_p}.txt")
    eval_results = {'bleu5': bleu5, 'sbleu5': sbleu5}
    with open(args.eval_results_file, "w") as writer:
        logger.info("***** SHOW the quantative evalution results *****")
        for key in sorted(eval_results.keys()):
            # FIX: the original wrote "%s %s" with no separator, so all
            # results ran together on one line; terminate each entry.
            writer.write("%s %s\n" % (key, str(eval_results[key])))