def main():
    args = make_parser().parse_args()
    print("[Model hyperparams]: {}".format(str(args)))

    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cpu") if not cuda else torch.device("cuda:" + str(args.gpu))
    seed_everything(seed=1337, cuda=cuda)
    vectors = None  # don't use pretrained vectors
    # vectors = load_pretrained_vectors(args.emsize)

    # Load dataset iterators
    iters, TEXT, LABEL, PROMPTS = dataset_map[args.data](
        args.batch_size, device=device, vectors=vectors, base_path=args.base_path)

    # Some datasets only have train & test sets, so we use the test set as validation
    if len(iters) >= 4:
        train_iter = iters[0]
        val_iter = iters[1]
        test_iter = iters[2]
        outdomain_test_iter = list(iters[3:])
    elif len(iters) == 3:
        train_iter, val_iter, test_iter = iters
    else:
        train_iter, test_iter = iters
        val_iter = test_iter

    print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
        len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab), len(LABEL.vocab)))

    if args.model == "CNN":
        args.embed_num = len(TEXT.vocab)
        args.nlabels = len(LABEL.vocab)
        args.nprompts = len(PROMPTS.vocab)
        args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
        args.embed_dim = args.emsize
        classifier_model = CNN_Text_GANLike(args)
        topic_decoder = [
            nn.Sequential(
                nn.Linear(len(args.kernel_sizes) * args.kernel_num, args.nprompts))
        ]
    else:
        ntokens, nlabels, nprompts = len(TEXT.vocab), len(LABEL.vocab), len(PROMPTS.vocab)
        args.nlabels = nlabels  # hack to not clutter function arguments
        args.nprompts = nprompts
        embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1)
        encoder = Encoder(args.emsize, args.hidden, nlayers=args.nlayers,
                          dropout=args.drop, bidirectional=args.bi,
                          rnn_type=args.rnn_model)
        attention_dim = args.hidden if not args.bi else 2 * args.hidden
        attention = BahdanauAttention(attention_dim, attention_dim)

        if args.bottleneck_dim == 0:
            classifier_model = Classifier_GANLike(embedding, encoder, attention,
                                                  attention_dim, nlabels)
            topic_decoder = [
                nn.Sequential(nn.Dropout(args.topic_drop),
                              nn.Linear(attention_dim, args.nprompts))
            ]
        else:
            classifier_model = Classifier_GANLike_bottleneck(
                embedding, encoder, attention, attention_dim, nlabels,
                bottleneck_dim=args.bottleneck_dim)
            topic_decoder = [
                nn.Sequential(nn.Dropout(args.topic_drop),
                              nn.Linear(args.bottleneck_dim, args.nprompts))
            ]

    classifier_model.to(device)
    topic_decoder[0].to(device)

    classify_criterion = nn.CrossEntropyLoss()
    topic_criterion = nn.CrossEntropyLoss()

    classify_optim = Optim(args.optim, args.lr, args.clip)
    topic_optim = Optim(args.optim, args.lr, args.clip)

    for p in classifier_model.parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    for p in topic_decoder[0].parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    classify_optim.set_parameters(classifier_model.parameters())
    topic_optim.set_parameters(topic_decoder[0].parameters())

    if args.load:
        if args.latest:
            best_model = torch.load(args.save_dir + "/" + args.model_name + "_latestmodel")
        else:
            best_model = torch.load(args.save_dir + "/" + args.model_name + "_bestmodel")
    else:
        try:
            best_valid_loss = None
            best_model = None

            # pretraining the classifier
            for epoch in range(1, args.pretrain_epochs + 1):
                pretrain_classifier(classifier_model, train_iter, classify_optim,
                                    classify_criterion, args, epoch)
                loss = evaluate(classifier_model, val_iter, classify_criterion, args)
                # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best pretrained_model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(best_model,
                               args.save_dir + "/" + args.model_name + "_pretrained_bestmodel")
                torch.save(classifier_model,
                           args.save_dir + "/" + args.model_name + "_pretrained_latestmodel")

            # classifier_model = best_model

            # alternating training like GANs
            for epoch in range(1, args.epochs + 1):
                for t_step in range(1, args.t_steps + 1):
                    train_topic_predictor(classifier_model, topic_decoder[-1], train_iter,
                                          topic_optim, topic_criterion, args, epoch, args.t_steps)

                if args.reset_classifier:
                    for p in classifier_model.parameters():
                        if not p.requires_grad:
                            print("OMG", p)
                            p.requires_grad = True
                        p.data.uniform_(-args.param_init, args.param_init)

                for c_step in range(1, args.c_steps + 1):
                    train_classifier(classifier_model, topic_decoder, train_iter, classify_optim,
                                     classify_criterion, topic_criterion, args, epoch, args.c_steps)

                loss = evaluate(classifier_model, val_iter, classify_criterion, args)
                # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                # creating a new instance of a decoder
                if args.model == "CNN":
                    topic_decoder += [
                        nn.Sequential(
                            nn.Linear(len(args.kernel_sizes) * args.kernel_num, args.nprompts))
                    ]
                else:
                    attention_dim = args.hidden if not args.bi else 2 * args.hidden
                    if args.bottleneck_dim == 0:
                        topic_decoder.append(
                            nn.Sequential(nn.Dropout(args.topic_drop),
                                          nn.Linear(attention_dim, args.nprompts)))
                    else:
                        topic_decoder.append(
                            nn.Sequential(nn.Dropout(args.topic_drop),
                                          nn.Linear(args.bottleneck_dim, args.nprompts)))

                # attaching a new optimizer to the new topic decoder
                topic_decoder[-1].to(device)
                topic_optim = Optim(args.optim, args.lr, args.clip)
                for p in topic_decoder[-1].parameters():
                    if not p.requires_grad:
                        print("OMG", p)
                        p.requires_grad = True
                    p.data.uniform_(-args.param_init, args.param_init)
                topic_optim.set_parameters(topic_decoder[-1].parameters())

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(best_model, args.save_dir + "/" + args.model_name + "_bestmodel")
                torch.save(classifier_model, args.save_dir + "/" + args.model_name + "_latestmodel")

        except KeyboardInterrupt:
            print("[Ctrl+C] Training stopped!")

    # if not args.load:
    trainloss = evaluate(best_model, train_iter, classify_criterion, args, datatype='train',
                         writetopics=args.save_output_topics,
                         itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
    valloss = evaluate(best_model, val_iter, classify_criterion, args, datatype='valid',
                       writetopics=args.save_output_topics,
                       itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
    loss = evaluate(best_model, test_iter, classify_criterion, args, datatype='test',
                    writetopics=args.save_output_topics,
                    itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)

    if args.data == "AMAZON":
        oodnames = args.oodname.split(",")
        for oodname, oodtest_iter in zip(oodnames, outdomain_test_iter):
            oodLoss = evaluate(best_model, oodtest_iter, classify_criterion, args,
                               datatype=oodname + "_bestmodel",
                               writetopics=args.save_output_topics)
            oodLoss = evaluate(classifier_model, oodtest_iter, classify_criterion, args,
                               datatype=oodname + "_latest",
                               writetopics=args.save_output_topics)
    else:
        oodLoss = evaluate(best_model, outdomain_test_iter[0], classify_criterion, args,
                           datatype="oodtest_bestmodel",
                           writetopics=args.save_output_topics,
                           itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
        oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args,
                           datatype="oodtest_latest",
                           writetopics=args.save_output_topics)
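# ---------------------------------------------------------------------------
# Hedged sketch (illustration only, not part of the script above).
# The alternating loop above repeatedly (1) fits a fresh topic decoder that
# predicts the prompt from the classifier's representation and (2) updates the
# classifier to keep label accuracy while defeating that decoder. The repo's
# train_topic_predictor / train_classifier implement this; the toy step below
# only shows the assumed shape of one alternation, with hypothetical modules
# (encoder, label_head, topic_head) and plain torch optimizers.
import torch
import torch.nn as nn


def _sketch_adversarial_step(encoder, label_head, topic_head,
                             x, y_label, y_topic,
                             label_opt, topic_opt, adv_weight=1.0):
    ce = nn.CrossEntropyLoss()

    # (1) topic step: fit the topic head on frozen representations
    topic_opt.zero_grad()
    with torch.no_grad():
        h = encoder(x)
    topic_loss = ce(topic_head(h), y_topic)
    topic_loss.backward()
    topic_opt.step()

    # (2) classifier step: keep label accuracy while pushing the topic head
    #     toward chance (gradient ascent on its loss)
    label_opt.zero_grad()
    h = encoder(x)
    label_loss = ce(label_head(h), y_label)
    adv_loss = -ce(topic_head(h), y_topic)
    (label_loss + adv_weight * adv_loss).backward()
    label_opt.step()
    return label_loss.item(), topic_loss.item()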
def main(): print("Loading data from '%s'" % opt.data) dataset = torch.load(opt.data) dict_checkpoint = opt.train_from if opt.train_from else opt.train_from_state_dict if dict_checkpoint: print('Loading dicts from checkpoint at %s' % dict_checkpoint) checkpoint = torch.load(dict_checkpoint) dataset['dicts'] = checkpoint['dicts'] trainData = Dataset(dataset['train']['src'], dataset['train']['tgt'], opt.batch_size, opt.gpus) validData = Dataset(dataset['valid']['src'], dataset['valid']['tgt'], opt.batch_size, opt.gpus, volatile=True) dicts = dataset['dicts'] print(' * vocabulary size. source = %d; target = %d' % (len(dicts["word2index"]['src']), len(dicts["word2index"]['tgt']))) print(' * number of training sentences. %d' % len(dataset['train']['src'])) print(' * maximum batch size. %d' % opt.batch_size) print('Building model...') encoder = Encoder(opt, len(dicts["word2index"]['src'])) decoder = Decoder(opt, len(dicts["word2index"]['tgt'])) generator = nn.Sequential( nn.Linear(opt.hidden_size * 2, len(dicts["word2index"]['tgt'])), nn.LogSoftmax()) model = NMTModel(encoder, decoder) if opt.train_from: print('Loading model from checkpoint at %s' % opt.train_from) chk_model = checkpoint['model'] generator_state_dict = chk_model.generator.state_dict() model_state_dict = { k: v for k, v in chk_model.state_dict().items() if 'generator' not in k } model.load_state_dict(model_state_dict) generator.load_state_dict(generator_state_dict) opt.start_epoch = checkpoint['epoch'] + 1 if opt.train_from_state_dict: print('Loading model from checkpoint at %s' % opt.train_from_state_dict) model.load_state_dict(checkpoint['model']) generator.load_state_dict(checkpoint['generator']) opt.start_epoch = checkpoint['epoch'] + 1 if len(opt.gpus) >= 1: model.cuda() generator.cuda() else: model.cpu() generator.cpu() if len(opt.gpus) > 1: model = nn.DataParallel(model, device_ids=opt.gpus, dim=1) generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0) model.generator = generator if not opt.train_from_state_dict and not opt.train_from: for p in model.parameters(): p.data.uniform_(-opt.param_init, opt.param_init) encoder.load_pretrained_vectors(opt) decoder.load_pretrained_vectors(opt) optim = Optim(opt.optim, opt.learning_rate, opt.max_grad_norm, lr_decay=opt.learning_rate_decay, start_decay_at=opt.start_decay_at) else: print('Loading optimizer from checkpoint:') optim = checkpoint['optim'] print(optim) optim.set_parameters(model.parameters()) if opt.train_from or opt.train_from_state_dict: optim.optimizer.load_state_dict( checkpoint['optim'].optimizer.state_dict()) nParams = sum([p.nelement() for p in model.parameters()]) print('* number of parameters: %d' % nParams) criterion = NMTCriterion(len(dicts["word2index"]['tgt'])) trainModel(model, trainData, validData, dataset, optim, criterion)
def main(hparams: HParams):
    '''Set up and run training.'''
    if torch.cuda.is_available() and not hparams.gpus:
        warnings.warn('WARNING: you have a CUDA device, so you should probably run with -gpus 0')
    device = torch.device(hparams.gpus if torch.cuda.is_available() else 'cpu')

    # data setup
    print("Loading vocabulary...")
    text_preprocessor = TextPreprocessor.load(hparams.preprocessor_path)
    transform = transforms.Compose([
        transforms.Resize([hparams.img_size, hparams.img_size]),
        transforms.RandomCrop([hparams.crop_size, hparams.crop_size]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # create dataloaders
    print('Creating DataLoader...')
    normal_data_loader = get_image_caption_loader(
        hparams.img_dir,
        hparams.normal_caption_path,
        text_preprocessor,
        hparams.normal_batch_size,
        transform,
        shuffle=True,
        num_workers=hparams.num_workers,
    )
    style_data_loader = get_caption_loader(
        hparams.style_caption_path,
        text_preprocessor,
        batch_size=hparams.style_batch_size,
        shuffle=True,
        num_workers=hparams.num_workers,
    )

    # optimizers are created on both paths; when resuming, their state is
    # restored from the checkpoint below
    normal_opt = Optim(
        hparams.optimizer,
        hparams.normal_lr,
        hparams.max_grad_norm,
        hparams.lr_decay,
        hparams.start_decay_at,
    )
    style_opt = Optim(
        hparams.optimizer,
        hparams.style_lr,
        hparams.max_grad_norm,
        hparams.lr_decay,
        hparams.start_decay_at,
    )

    if hparams.train_from:
        # load checkpoint
        print('Loading checkpoint...')
        checkpoint = torch.load(hparams.train_from)

    print('Building model...')
    encoder = EncoderCNN(hparams.hidden_dim)
    decoder = FactoredLSTM(hparams.embed_dim, text_preprocessor.vocab_size,
                           hparams.hidden_dim, hparams.style_dim,
                           hparams.num_layers, hparams.random_init,
                           hparams.dropout_ratio, train=True, device=device)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # loss and optimizers
    criterion = nn.CrossEntropyLoss(ignore_index=text_preprocessor.PAD_ID)
    normal_params = list(encoder.parameters()) + list(decoder.default_parameters())
    style_params = list(decoder.style_parameters())
    normal_opt.set_parameters(normal_params)
    style_opt.set_parameters(style_params)

    if hparams.train_from:
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        normal_opt.load_state_dict(checkpoint['normal_opt'])
        style_opt.load_state_dict(checkpoint['style_opt'])

    # training loop
    print('Start training...')
    for epoch in range(hparams.num_epoch):
        # per-epoch statistics
        sum_normal_loss, sum_style_loss, sum_normal_ppl, sum_style_ppl = 0, 0, 0, 0

        # normal captions
        for i, (images, in_captions, out_captions, lengths) in enumerate(normal_data_loader):
            images = images.to(device)
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # forward, backward and optimize
            features = encoder(images)
            outputs = decoder(in_captions, features, mode='default')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)
            encoder.zero_grad()
            decoder.zero_grad()
            loss.backward()
            normal_opt.step()

            # print log
            sum_normal_loss += loss.item()
            sum_normal_ppl += np.exp(loss.item())
            if i % hparams.normal_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Normal Step: [{i}/{len(normal_data_loader)}] '
                    f'Normal Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        # style captions
        for i, (in_captions, out_captions, lengths) in enumerate(style_data_loader):
            in_captions = in_captions.to(device)
            out_captions = out_captions.contiguous().view(-1).to(device)

            # forward, backward and optimize
            outputs = decoder(in_captions, None, mode='style')
            loss = criterion(outputs.view(-1, outputs.size(-1)), out_captions)
            decoder.zero_grad()
            loss.backward()
            style_opt.step()

            sum_style_loss += loss.item()
            sum_style_ppl += np.exp(loss.item())

            # print log
            if i % hparams.style_log_step == 0:
                print(
                    f'Epoch [{epoch}/{hparams.num_epoch}], Style Step: [{i}/{len(style_data_loader)}] '
                    f'Style Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}'
                )

        model_params = {
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'epoch': epoch,
            'normal_opt': normal_opt.optimizer.state_dict(),
            'style_opt': style_opt.optimizer.state_dict(),
        }

        avg_normal_loss = sum_normal_loss / len(normal_data_loader)
        avg_style_loss = sum_style_loss / len(style_data_loader)
        avg_normal_ppl = sum_normal_ppl / len(normal_data_loader)
        avg_style_ppl = sum_style_ppl / len(style_data_loader)
        print(f'Epoch [{epoch}/{hparams.num_epoch}] statistics')
        print(
            f'Normal Loss: {avg_normal_loss:.4f} Normal ppl: {avg_normal_ppl:5.4f} '
            f'Style Loss: {avg_style_loss:.4f} Style ppl: {avg_style_ppl:5.4f}'
        )

        torch.save(
            model_params,
            f'{hparams.model_path}/n-loss_{avg_normal_loss:.4f}_s-loss_{avg_style_loss:.4f}_'
            f'n-ppl_{avg_normal_ppl:5.4f}_s-ppl_{avg_style_ppl:5.4f}_epoch_{epoch}.pt'
        )
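# ---------------------------------------------------------------------------
# Hedged sketch (toy module, not the repo's FactoredLSTM).
# The loop above depends on the decoder exposing two disjoint parameter
# groups: decoder.default_parameters() (updated on normal captions together
# with the CNN encoder) and decoder.style_parameters() (updated on styled
# captions only). The class below shows that split with made-up attribute
# names; only the parameter-grouping pattern matches the training loop, not
# the actual factored LSTM math.
import torch
import torch.nn as nn


class _SketchFactoredDecoder(nn.Module):
    def __init__(self, dim=16, vocab=100):
        super().__init__()
        self.shared = nn.Linear(dim, dim)   # stands in for the factual factors
        self.style = nn.Linear(dim, dim)    # stands in for the style factors
        self.out = nn.Linear(dim, vocab)

    def forward(self, h, mode='default'):
        h = self.shared(h) if mode == 'default' else self.style(self.shared(h))
        return self.out(h)

    def default_parameters(self):
        return list(self.shared.parameters()) + list(self.out.parameters())

    def style_parameters(self):
        return list(self.style.parameters())


def _sketch_two_group_optimizers():
    # one optimizer per parameter group, mirroring normal_opt / style_opt above
    decoder = _SketchFactoredDecoder()
    normal_opt = torch.optim.Adam(decoder.default_parameters(), lr=1e-3)
    style_opt = torch.optim.Adam(decoder.style_parameters(), lr=1e-3)
    return decoder, normal_opt, style_opt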
def main():
    args = make_parser().parse_args()
    print("[Model hyperparams]: {}".format(str(args)))

    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cpu") if not cuda else torch.device("cuda:" + str(args.gpu))
    seed_everything(seed=1337, cuda=cuda)
    vectors = None  # don't use pretrained vectors
    # vectors = load_pretrained_vectors(args.emsize)

    # Load dataset iterators
    if args.data in ["RT_GENDER"]:
        if args.finetune:
            iters, TEXT, LABEL, INDEX = make_rt_gender(
                args.batch_size, base_path=args.base_path, train_file=args.train_file,
                valid_file=args.valid_file, test_file=args.test_file,
                device=device, vectors=vectors, topics=False)
        else:
            iters, TEXT, LABEL, TOPICS, INDEX = make_rt_gender(
                args.batch_size, base_path=args.base_path, train_file=args.train_file,
                valid_file=args.valid_file, test_file=args.test_file,
                device=device, vectors=vectors, topics=True)
        train_iter, val_iter, test_iter = iters
    else:
        assert False

    if not args.finetune:
        for batch in train_iter:
            args.num_topics = batch.topics.shape[1]
            break

    print("[Corpus]: train: {}, test: {}, vocab: {}, labels: {}".format(
        len(train_iter.dataset), len(test_iter.dataset), len(TEXT.vocab), len(LABEL.vocab)))

    if args.model == "CNN":
        args.embed_num = len(TEXT.vocab)
        args.nlabels = len(LABEL.vocab)
        args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
        args.embed_dim = args.emsize
        classifier_model = CNN_Text_GANLike(args)
        # kept in a list so that topic_decoder[0] / topic_decoder[-1] below work,
        # matching the RNN branch
        topic_decoder = [nn.Sequential(
            nn.Linear(len(args.kernel_sizes) * args.kernel_num, args.num_topics),
            nn.LogSoftmax(dim=-1))]
    else:
        ntokens, nlabels = len(TEXT.vocab), len(LABEL.vocab)
        args.nlabels = nlabels  # hack to not clutter function arguments
        embedding = nn.Embedding(ntokens, args.emsize, padding_idx=1)
        encoder = Encoder(args.emsize, args.hidden, nlayers=args.nlayers,
                          dropout=args.drop, bidirectional=args.bi,
                          rnn_type=args.rnn_model)
        attention_dim = args.hidden if not args.bi else 2 * args.hidden
        attention = BahdanauAttention(attention_dim, attention_dim)

        if args.bottleneck_dim == 0:
            classifier_model = Classifier_GANLike(embedding, encoder, attention,
                                                  attention_dim, nlabels)
            topic_decoder = [nn.Sequential(nn.Dropout(args.topic_drop),
                                           nn.Linear(attention_dim, args.num_topics),
                                           nn.LogSoftmax())]
        else:
            classifier_model = Classifier_GANLike_bottleneck(
                embedding, encoder, attention, attention_dim, nlabels,
                bottleneck_dim=args.bottleneck_dim)
            topic_decoder = [nn.Sequential(nn.Dropout(args.topic_drop),
                                           nn.Linear(args.bottleneck_dim, args.num_topics),
                                           nn.LogSoftmax())]

    classifier_model.to(device)
    topic_decoder[0].to(device)

    classify_criterion = nn.CrossEntropyLoss()
    topic_criterion = nn.KLDivLoss(size_average=False)

    classify_optim = Optim(args.optim, args.lr, args.clip)
    topic_optim = Optim(args.optim, args.lr, args.clip)

    for p in classifier_model.parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    for p in topic_decoder[0].parameters():
        if not p.requires_grad:
            print("OMG", p)
            p.requires_grad = True
        p.data.uniform_(-args.param_init, args.param_init)

    classify_optim.set_parameters(classifier_model.parameters())
    topic_optim.set_parameters(topic_decoder[0].parameters())

    if args.load:
        if args.latest:
            best_model = torch.load(args.save_dir + "/" + args.model_name + "_latestmodel")
        else:
            best_model = torch.load(args.save_dir + "/" + args.model_name + "_bestmodel")
    else:
        try:
            best_valid_loss = None
            best_model = None

            # pretraining the classifier
            for epoch in range(1, args.pretrain_epochs + 1):
                pretrain_classifier(classifier_model, train_iter, classify_optim,
                                    classify_criterion, args, epoch)
                loss = evaluate(classifier_model, topic_decoder, val_iter,
                                classify_criterion, topic_criterion, args)
                # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best pretrained_model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(best_model,
                               args.save_dir + "/" + args.model_name + "_pretrained_bestmodel")
                torch.save(classifier_model,
                           args.save_dir + "/" + args.model_name + "_pretrained_latestmodel")

            print("Done pretraining")
            print()

            best_valid_loss = None
            best_model = None

            # alternating training like GANs
            for epoch in range(1, args.epochs + 1):
                for t_step in range(1, args.t_steps + 1):
                    print()
                    print("Training topic predictor")
                    train_topic_predictor(classifier_model, topic_decoder[-1], train_iter,
                                          topic_optim, topic_criterion, args, epoch, args.t_steps)

                if args.reset_classifier:
                    for p in classifier_model.parameters():
                        if not p.requires_grad:
                            print("OMG", p)
                            p.requires_grad = True
                        p.data.uniform_(-args.param_init, args.param_init)

                for c_step in range(1, args.c_steps + 1):
                    print()
                    print("Training classifier")
                    train_classifier(classifier_model, topic_decoder, train_iter, classify_optim,
                                     classify_criterion, topic_criterion, args, epoch, args.c_steps)

                loss = evaluate(classifier_model, topic_decoder, val_iter,
                                classify_criterion, topic_criterion, args)
                # oodLoss = evaluate(classifier_model, outdomain_test_iter[0], classify_criterion, args, datatype="oodtest")

                # creating a new instance of a decoder
                attention_dim = args.hidden if not args.bi else 2 * args.hidden
                if args.bottleneck_dim == 0:
                    topic_decoder.append(nn.Sequential(nn.Dropout(args.topic_drop),
                                                       nn.Linear(attention_dim, args.num_topics),
                                                       nn.LogSoftmax()))
                else:
                    topic_decoder.append(nn.Sequential(nn.Dropout(args.topic_drop),
                                                       nn.Linear(args.bottleneck_dim, args.num_topics),
                                                       nn.LogSoftmax()))

                # attaching a new optimizer to the new topic decoder
                topic_decoder[-1].to(device)
                topic_optim = Optim(args.optim, args.lr, args.clip)
                for p in topic_decoder[-1].parameters():
                    if not p.requires_grad:
                        print("OMG", p)
                        p.requires_grad = True
                    p.data.uniform_(-args.param_init, args.param_init)
                topic_optim.set_parameters(topic_decoder[-1].parameters())

                if not best_valid_loss or loss < best_valid_loss:
                    best_valid_loss = loss
                    print("Updating best model")
                    best_model = copy.deepcopy(classifier_model)
                    torch.save(best_model, args.save_dir + "/" + args.model_name + "_bestmodel")
                torch.save(classifier_model, args.save_dir + "/" + args.model_name + "_latestmodel")

        except KeyboardInterrupt:
            print("[Ctrl+C] Training stopped!")

    if args.finetune:
        best_valid_loss = None
        for c_step in range(1, args.c_steps + 1):
            print()
            print("Fine-tuning classifier")
            train_classifier(classifier_model, None, train_iter, classify_optim,
                             classify_criterion, None, args, c_step, args.c_steps)
            loss = evaluate(classifier_model, topic_decoder, val_iter,
                            classify_criterion, topic_criterion, args)
            if not best_valid_loss or loss < best_valid_loss:
                best_valid_loss = loss
                print("Updating best model")
                best_model = copy.deepcopy(classifier_model)
                torch.save(best_model, args.save_dir + "/" + args.model_name + "finetune_bestmodel")
            torch.save(classifier_model, args.save_dir + "/" + args.model_name + "finetune_latestmodel")

    if not args.load:
        trainloss = evaluate(best_model, topic_decoder, train_iter, classify_criterion,
                             topic_criterion, args, datatype='train',
                             itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
        valloss = evaluate(best_model, topic_decoder, val_iter, classify_criterion,
                           topic_criterion, args, datatype='valid',
                           itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
        loss = evaluate(best_model, topic_decoder, test_iter, classify_criterion,
                        topic_criterion, args,
                        datatype=os.path.basename(args.test_file).replace(".txt", "").replace(".tsv", ""),
                        itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
        if args.ood_test_file:
            # NOTE: this call re-uses test_iter; only the datatype label changes
            loss = evaluate(best_model, topic_decoder, test_iter, classify_criterion,
                            topic_criterion, args,
                            datatype=os.path.basename(args.ood_test_file).replace(".txt", "").replace(".tsv", ""),
                            itos=TEXT.vocab.itos, litos=LABEL.vocab.itos)
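# ---------------------------------------------------------------------------
# Hedged sketch (toy tensors, not the RT_GENDER data).
# In this variant each topic decoder ends in LogSoftmax and is scored with
# nn.KLDivLoss against a soft per-example topic distribution (batch.topics),
# instead of CrossEntropyLoss against a hard prompt id as in the first script.
# KLDivLoss expects log-probabilities as input and probabilities as target;
# size_average=False above is the older spelling of reduction='sum'.
import torch
import torch.nn as nn


def _sketch_topic_kl_loss(num_topics=5, batch=3, dim=8):
    decoder = nn.Sequential(nn.Linear(dim, num_topics), nn.LogSoftmax(dim=-1))
    hidden = torch.randn(batch, dim)  # stand-in for the classifier's representation
    target = torch.softmax(torch.randn(batch, num_topics), dim=-1)  # soft topic mixture
    log_probs = decoder(hidden)
    return nn.KLDivLoss(reduction='sum')(log_probs, target)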