def main(opt_):
    """Entry point: evaluate a pretrained model if one is given, else train.

    Args:
        opt_: parsed options namespace. When ``opt_.pretrained`` is set the
            checkpoint is loaded and evaluated on the valid and hidden test
            splits (both unshuffled and shuffled); otherwise training is
            delegated to ``train_model``.
    """
    if not opt_.pretrained:
        train_model(opt_)
        return
    net, dictionary = load_model(opt_.pretrained, opt_)
    # Propagate the run-time dataset options onto the loaded model's opt so
    # the environment is built against the current data, not the data the
    # checkpoint was originally trained with.
    net.opt.dataset_name = opt_.dataset_name
    net.opt.reddit_folder = opt_.reddit_folder
    net.opt.reactonly = opt_.reactonly
    net.opt.max_hist_len = opt_.max_hist_len
    env = TrainEnvironment(net.opt, dictionary)
    if opt_.cuda:
        net = torch.nn.DataParallel(net.cuda())
    # Evaluate both splits, first without shuffling, then with shuffling,
    # matching the original evaluation order.
    for is_shuffled in (False, True):
        _evaluate_splits(net, env, opt_, is_shuffled)


def _evaluate_splits(net, env, opt_, is_shuffled):
    """Validate `net` on the valid and hidden-test splits.

    Args:
        net: the (possibly DataParallel-wrapped) model to evaluate.
        env: TrainEnvironment providing the dataloaders.
        opt_: options namespace (supplies ``hits_at_nb_cands``).
        is_shuffled: whether the dataloaders shuffle candidates; also
            selects the log-message suffix ("shuffle" vs "unshuffled").
    """
    # Log suffixes reproduce the original messages byte-for-byte.
    suffix = "shuffle" if is_shuffled else "unshuffled"
    valid_data = env.build_valid_dataloader(is_shuffled)
    test_data = env.build_valid_dataloader(is_shuffled, test=True)
    with torch.no_grad():
        logging.info(f"Validating on the valid set -{suffix}")
        validate(
            0,
            net,
            valid_data,
            is_test=False,
            nb_candidates=opt_.hits_at_nb_cands,
            is_shuffled=is_shuffled,
        )
        logging.info(f"Validating on the hidden test set -{suffix}")
        validate(
            0,
            net,
            test_data,
            is_test=True,
            nb_candidates=opt_.hits_at_nb_cands,
            is_shuffled=is_shuffled,
        )
def train_model(opt_):
    """Train a model, optionally resuming from a checkpoint.

    Builds the training environment and model, runs the epoch loop with
    per-epoch validation (both shuffled and unshuffled), saves the model
    whenever the tracked loss improves, and supports early stopping.

    Args:
        opt_: parsed options namespace.

    Returns:
        (net, dictionary): the trained network and its word dictionary.
    """
    env = TrainEnvironment(opt_)
    dictionary = env.dict
    if opt_.load_checkpoint:
        net, dictionary = load_model(opt_.load_checkpoint, opt_)
        # Rebuild the environment around the checkpoint's dictionary so
        # token indexing matches the loaded weights.
        env = TrainEnvironment(opt_, dictionary)
        env.dict = dictionary
    else:
        net = create_model(opt_, dictionary["words"])
        if opt_.embeddings and opt_.embeddings != "None":
            load_embeddings(opt_, dictionary["words"], net)
    _log_param_counts(net)
    if opt_.cuda:
        net = torch.nn.DataParallel(net)
        net = net.cuda()
    optimizer = _build_optimizer(opt_, net)
    start_time = time.time()
    best_loss = float("+inf")
    # FIX: initialize best_loss_epoch so the early-stopping check below can
    # never hit an unbound name. Previously it was only assigned when
    # `loss < best_loss` succeeded — if the loss is NaN from the first
    # epoch the comparison is always False and the stop-criterion check
    # raised NameError.
    best_loss_epoch = opt_.epoch_start
    test_data_shuffled = env.build_valid_dataloader(True)
    test_data_not_shuffled = env.build_valid_dataloader(False)
    # Baseline validation before any training epoch runs.
    with torch.no_grad():
        validate(
            0,
            net,
            test_data_shuffled,
            nb_candidates=opt_.hits_at_nb_cands,
            shuffled_str="shuffled",
        )
    train_data = None
    for epoch in range(opt_.epoch_start, opt_.num_epochs):
        # Reddit data is re-sampled every epoch; other datasets reuse the
        # loader built on the first pass.
        if train_data is None or opt_.dataset_name == "reddit":
            train_data = env.build_train_dataloader(epoch)
        train(epoch, start_time, net, optimizer, opt_, train_data)
        with torch.no_grad():
            # We compute the loss both for shuffled and not shuffled case.
            # However, the loss that determines if the model is better is
            # the same as the one used for training.
            loss_shuffled = validate(
                epoch,
                net,
                test_data_shuffled,
                nb_candidates=opt_.hits_at_nb_cands,
                shuffled_str="shuffled",
            )
            loss_not_shuffled = validate(
                epoch,
                net,
                test_data_not_shuffled,
                nb_candidates=opt_.hits_at_nb_cands,
                shuffled_str="not-shuffled",
            )
            loss = loss_not_shuffled if opt_.no_shuffle else loss_shuffled
            if loss < best_loss:
                best_loss = loss
                best_loss_epoch = epoch
                logging.info(
                    f"New best loss, saving model to {opt_.model_file}")
                save_model(opt_.model_file, net, dictionary, optimizer)
            # Stop if it's been too many epochs since the loss has decreased
            if opt_.stop_crit_num_epochs != -1:
                if epoch - best_loss_epoch >= opt_.stop_crit_num_epochs:
                    break
    return net, dictionary


def _log_param_counts(net):
    """Print total and trainable parameter counts for `net`."""
    paramnum = 0
    trainable = 0
    for _, parameter in net.named_parameters():
        if parameter.requires_grad:
            trainable += parameter.numel()
        paramnum += parameter.numel()
    print("TRAINABLE", paramnum, trainable)


def _build_optimizer(opt_, net):
    """Create the optimizer configured by ``opt_.optimizer``.

    For Adamax, restores the saved optimizer state when resuming a run
    (``opt_.epoch_start != 0``). Falls back to SGD otherwise.
    """
    if opt_.optimizer == "adamax":
        lr = opt_.learning_rate or 0.002
        params_to_optimize = (
            p for _, p in net.named_parameters() if p.requires_grad
        )
        optimizer = optim.Adamax(params_to_optimize, lr=lr)
        if opt_.epoch_start != 0:
            # Resuming mid-run: restore optimizer moments from the
            # checkpoint (CPU map so it loads regardless of device).
            saved_params = torch.load(
                opt_.load_checkpoint,
                map_location=lambda storage, loc: storage,
            )
            optimizer.load_state_dict(saved_params["optim_dict"])
        return optimizer
    lr = opt_.learning_rate or 0.01
    return optim.SGD(
        filter(lambda p: p.requires_grad, net.parameters()), lr=lr
    )
"--task", type=str, choices=["dailydialog", "empchat", "reddit"], default="empchat", help="Dataset for context/target-response pairs", ) args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() if args.cuda: torch.cuda.set_device(args.gpu) logger.info(f"CUDA enabled (GPU {args.gpu:d})") else: logger.info("Running on CPU only.") if args.fasttext is not None: args.max_cand_length += args.fasttext net, net_dictionary = load_model(args.model, get_opt(existing_opt=args)) if "bert_tokenizer" in net_dictionary: if args.task == "dailydialog": raise NotImplementedError( "BERT model currently incompatible with DailyDialog!") if args.bleu_dict is not None: _, bleu_dictionary = load_model(args.bleu_dict, get_opt(existing_opt=args)) else: bleu_dictionary = net_dictionary paramnum = 0 trainable = 0 for parameter in net.parameters(): if parameter.requires_grad: trainable += parameter.numel() paramnum += parameter.numel() print(paramnum, trainable)