import torch.nn as nn

from rnn_model.build_model import get_model as get_rnn_model
from double_input_rnn_model.build_model import get_model as get_double_input_rnn_model


def get_model(corpus, model_type_to_word, model_type, args, attention_model=False):
    # Build the two sub-models if the caller did not supply them; the builders
    # return (model, criterion, params), so keep only the model.
    if model_type_to_word is None:
        model_type_to_word, _, _ = get_double_input_rnn_model(corpus, args)
    if model_type is None:
        model_type, _, _ = get_rnn_model(corpus, args)

    # combined_model (the wrapper nn.Module) is defined elsewhere in this module.
    model = combined_model(model_type_to_word, model_type)
    criterion = nn.CrossEntropyLoss()

    if args["cuda"]:
        model = model.cuda()
        criterion = criterion.cuda()

    # Optimize model and criterion parameters together.
    params = list(model.parameters()) + list(criterion.parameters())
    return model, criterion, params
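# Usage sketch (an assumption, not code from the repository): the returned parameter
# list is meant to be handed straight to an optimizer, mirroring the training script
# further below. `load_combined_corpus` and `data_combined_path` are assumed to be
# importable from data_utils / main, as in the other scripts in this section.
if __name__ == '__main__':
    import torch
    from data_utils import load_combined_corpus
    from main import args, data_combined_path

    corpus_combined = load_combined_corpus(data_combined_path)
    model, criterion, params = get_model(corpus_combined, None, None, args)
    optimizer = torch.optim.Adam(params, lr=args["lr"], weight_decay=args["wdecay"])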
        total_loss += loss.item()
    # Tail of the evaluate() helper: report the average loss over the whole loader.
    return total_loss / len(data_loader)


def train_and_eval(model, train_dataloader, valid_dataloader, test_dataloader, num_epochs, criterion, optimizer, log_interval):
    for e in range(1, num_epochs + 1):
        print('-' * 89)
        print("epoch : ", e)
        train(model, train_dataloader, criterion, optimizer, log_interval)
        valid_loss = evaluate(model, valid_dataloader)
        print("valid loss : ", valid_loss)
        print('-' * 89)
    # Evaluate on the held-out test set once training has finished.
    test_loss = evaluate(model, test_dataloader)
    print("test loss : ", test_loss)


if __name__ == '__main__':
    recipe_data_path = os.path.join('recipe_data', 'data_with_ingredients')
    train_dataloader, valid_dataloader, test_dataloader, ingred_vocab_size, recipe_vocab_size = get_data_loader(recipe_data_path, args)

    # Build the pretrained combined language model, then wrap it in the
    # encoder-decoder recipe model.
    corpus_combined = load_combined_corpus(data_combined_path)
    model_type, _, _ = get_rnn_model(corpus_combined, lm_args)
    model_entity_composite, _, _ = get_double_input_rnn_model(corpus_combined, lm_args)
    model_combined, _, _ = get_combined_model(corpus_combined, model_entity_composite, model_type, lm_args)
    model = get_encoder_decoder_model(ingred_vocab_size, recipe_vocab_size, model_combined, args)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_and_eval(model, train_dataloader, valid_dataloader, test_dataloader,
                   args['num_epochs'], criterion, optimizer, args['log_interval'])
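# Hedged sketch, not the repository's implementation: evaluate() is truncated at the
# top of this file, so this is one plausible reconstruction of the full helper. It
# assumes the dataloader yields (source, target) LongTensor pairs and that the
# encoder-decoder forward pass returns logits over the recipe vocabulary.
import torch
import torch.nn as nn


def evaluate_sketch(model, data_loader):
    criterion = nn.CrossEntropyLoss()
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for source, target in data_loader:
            output = model(source, target)  # assumed forward signature
            # Flatten time and batch dimensions before computing cross-entropy.
            loss = criterion(output.view(-1, output.size(-1)), target.view(-1))
            total_loss += loss.item()
    return total_loss / len(data_loader)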
import torch

from data_utils import load_hierarchical_corpus
from main import args, data_hierarchical_path, model_type_save_path, model_type_to_subtype_save_path, model_subtype_to_word_save_path, model_hierarchical_save_path
from rnn_model.build_model import get_model as get_rnn_model
from double_input_rnn_model.build_model import get_model as get_double_input_rnn_model
from hierarchical_model.build_model import get_model as get_hierarchical_model
from rnn_model.utils import model_load
from beam_search import beam_search, idx2word

device = "cuda:0" if args["cuda"] else "cpu"

with torch.no_grad():
    corpus_hierarchical = load_hierarchical_corpus(data_hierarchical_path)

    model_type, _, _ = get_rnn_model(corpus_hierarchical, args)
    model_type_to_subtype, _, _ = get_double_input_rnn_model(corpus_hierarchical, args)
    model_subtype_to_word, _, _ = get_double_input_rnn_model(corpus_hierarchical, args)
    model_hierarchical, _, _ = get_hierarchical_model(corpus_hierarchical, model_subtype_to_word, model_type_to_subtype, model_type, args)

    model_hierarchical_state_dict, _, _ = model_load(model_hierarchical_save_path, device)
    model_hierarchical.load_state_dict(model_hierarchical_state_dict)

    initial_sentence = 'season chicken with salt and pepper .'
    print()
    print("initial_sentence : ", initial_sentence.replace('_', ' '))
    print()
    initial_sentence = initial_sentence.split()

    hidden_hierarchical = model_hierarchical.init_hidden(1)
    search = ' '.join([idx2word(word, corpus_hierarchical)
                       for word in beam_search(model_hierarchical, corpus_hierarchical, hidden_hierarchical, initial_sentence)[-1]])
    search = search.replace('_', ' ')
    print(search)
torch.manual_seed(args["seed"])
if torch.cuda.is_available():
    if not args["cuda"]:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.set_device(0)
        torch.cuda.manual_seed(args["seed"])

corpus_combined = load_combined_corpus(data_combined_path)

# type model
corpus_type = load_text_corpus(data_combined_path, 'data_ori', 'data_type', corpus_combined.dictionary)
model_type, criterion_type, params_model_type = get_rnn_model(corpus_type, args)

if args["optimizer"] == "sgd":
    optimizer_model_type = torch.optim.SGD(params_model_type, lr=args["lr"], weight_decay=args["wdecay"])
else:
    optimizer_model_type = torch.optim.Adam(params_model_type, lr=args["lr"], weight_decay=args["wdecay"])

if not os.path.exists(model_type_save_path):
    train_and_eval_rnn_model(model_type, corpus_type, optimizer_model_type, criterion_type, params_model_type, args, model_type_save_path)
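# Hedged sketch, not part of the original excerpt: once the type model has been trained
# and its checkpoint written, it could be restored with the same model_load helper used
# in the beam-search script above (model_load returns the state dict as its first value).
if os.path.exists(model_type_save_path):
    from rnn_model.utils import model_load

    model_type_state_dict, _, _ = model_load(model_type_save_path, "cuda:0" if args["cuda"] else "cpu")
    model_type.load_state_dict(model_type_state_dict)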