def main(args: argparse.Namespace):
    # Load input data
    with open(args.train_metadata, 'r') as f:
        train_posts = json.load(f)

    with open(args.val_metadata, 'r') as f:
        val_posts = json.load(f)

    # Load labels and map each label name to an integer id
    labels = {}
    with open(args.label_intent, 'r') as f:
        intent_labels = json.load(f)
    labels['intent'] = {}
    for label in intent_labels:
        labels['intent'][label] = len(labels['intent'])

    with open(args.label_semiotic, 'r') as f:
        semiotic_labels = json.load(f)
    labels['semiotic'] = {}
    for label in semiotic_labels:
        labels['semiotic'][label] = len(labels['semiotic'])

    with open(args.label_contextual, 'r') as f:
        contextual_labels = json.load(f)
    labels['contextual'] = {}
    for label in contextual_labels:
        labels['contextual'][label] = len(labels['contextual'])

    # Build dictionary from training set
    train_captions = []
    for post in train_posts:
        train_captions.append(post['orig_caption'])
    dictionary = Dictionary(tokenizer_method="TreebankWordTokenizer")
    dictionary.build_dictionary_from_captions(train_captions)

    # Set up torch device
    if 'cuda' in args.device and torch.cuda.is_available():
        device = torch.device(args.device)
        kwargs = {'pin_memory': True}
    else:
        device = torch.device('cpu')
        kwargs = {}

    # Set up number of workers
    num_workers = min(multiprocessing.cpu_count(), args.num_workers)

    # Set up data loaders differently based on the task
    # TODO: Extend to ELMo + word2vec etc.
    if args.type == 'image_only':
        train_dataset = ImageOnlyDataset(train_posts, labels)
        val_dataset = ImageOnlyDataset(val_posts, labels)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=args.shuffle,
                                                        num_workers=num_workers,
                                                        collate_fn=collate_fn_pad_image_only,
                                                        **kwargs)
        val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                      batch_size=1,
                                                      num_workers=num_workers,
                                                      collate_fn=collate_fn_pad_image_only,
                                                      **kwargs)
    elif args.type == 'image_text':
        train_dataset = ImageTextDataset(train_posts, labels, dictionary)
        val_dataset = ImageTextDataset(val_posts, labels, dictionary)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=args.shuffle,
                                                        num_workers=num_workers,
                                                        collate_fn=collate_fn_pad_image_text,
                                                        **kwargs)
        val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                      batch_size=1,
                                                      num_workers=num_workers,
                                                      collate_fn=collate_fn_pad_image_text,
                                                      **kwargs)
    elif args.type == 'text_only':
        train_dataset = TextOnlyDataset(train_posts, labels, dictionary)
        val_dataset = TextOnlyDataset(val_posts, labels, dictionary)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=args.shuffle,
                                                        num_workers=num_workers,
                                                        collate_fn=collate_fn_pad_text_only,
                                                        **kwargs)
        val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                      batch_size=1,
                                                      num_workers=num_workers,
                                                      collate_fn=collate_fn_pad_text_only,
                                                      **kwargs)

    # Set up the model
    model = Model(vocab_size=dictionary.size()).to(device)

    # Set up an optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=args.lr_scheduler_step_size,
                                                gamma=args.lr_scheduler_gamma)  # decay by 0.1 every 15 epochs

    # Set up loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Set up tensorboard
    if args.tensorboard:
        writer = tensorboard.SummaryWriter(log_dir=args.log_dir + "/" + args.name,
                                           flush_secs=1)
    else:
        writer = None

    # Training loop
    if args.classification == 'intent':
        keys = ['intent']
    elif args.classification == 'semiotic':
        keys = ['semiotic']
    elif args.classification == 'contextual':
        keys = ['contextual']
    elif args.classification == 'all':
        keys = ['intent', 'semiotic', 'contextual']
    else:
        raise ValueError("args.classification doesn't exist.")

    best_auc_ovr = 0.0
    best_auc_ovo = 0.0
    best_acc = 0.0
    best_model = None
    best_optimizer = None
    best_scheduler = None

    for epoch in range(args.epochs):
        for mode in ["train", "eval"]:
            # Set up a progress bar
            if mode == "train":
                pbar = tqdm.tqdm(enumerate(train_data_loader), total=len(train_data_loader))
                model.train()
            else:
                pbar = tqdm.tqdm(enumerate(val_data_loader), total=len(val_data_loader))
                model.eval()

            total_loss = 0
            label = dict.fromkeys(keys, np.array([], dtype=int))
            pred = dict.fromkeys(keys, None)
            for _, batch in pbar:
                if 'caption' not in batch:
                    caption_data = None
                else:
                    caption_data = batch['caption'].to(device)
                if 'image' not in batch:
                    image_data = None
                else:
                    image_data = batch['image'].to(device)
                label_batch = {}
                for key in keys:
                    label_batch[key] = batch['label'][key].to(device)

                if mode == "train":
                    model.zero_grad()

                pred_batch = model(image_data, caption_data)

                for key in keys:
                    label[key] = np.concatenate((label[key], batch['label'][key].cpu().numpy()))
                    # Numerically stable softmax over the logits to get prediction scores
                    x = pred_batch[key].detach().cpu().numpy()
                    x_max = np.max(x, axis=1).reshape(-1, 1)
                    z = np.exp(x - x_max)
                    prediction_scores = z / np.sum(z, axis=1).reshape(-1, 1)
                    if pred[key] is not None:
                        pred[key] = np.vstack((pred[key], prediction_scores))
                    else:
                        pred[key] = prediction_scores

                loss_batch = {}
                loss = None
                for key in keys:
                    loss_batch[key] = loss_fn(pred_batch[key], label_batch[key])
                    if loss is None:
                        loss = loss_batch[key]
                    else:
                        loss += loss_batch[key]
                total_loss += loss.item()

                if mode == "train":
                    loss.backward()
                    optimizer.step()

            # Terminate the progress bar
            pbar.close()

            # Update lr scheduler
            if mode == "train":
                scheduler.step()

            for key in keys:
                auc_score_ovr = roc_auc_score(label[key], pred[key], multi_class='ovr')  # pylint: disable-all
                auc_score_ovo = roc_auc_score(label[key], pred[key], multi_class='ovo')  # pylint: disable-all
                accuracy = accuracy_score(label[key], np.argmax(pred[key], axis=1))
                print("[{} - {}] [AUC-OVR={:.3f}, AUC-OVO={:.3f}, ACC={:.3f}]".format(
                    mode, key, auc_score_ovr, auc_score_ovo, accuracy))

                if mode == "eval":
                    best_auc_ovr = max(best_auc_ovr, auc_score_ovr)
                    best_auc_ovo = max(best_auc_ovo, auc_score_ovo)
                    best_acc = max(best_acc, accuracy)
                    best_model = model
                    best_optimizer = optimizer
                    best_scheduler = scheduler

                if writer:
                    writer.add_scalar('AUC-OVR/{}-{}'.format(mode, key), auc_score_ovr, epoch)
                    writer.add_scalar('AUC-OVO/{}-{}'.format(mode, key), auc_score_ovo, epoch)
                    writer.add_scalar('ACC/{}-{}'.format(mode, key), accuracy, epoch)
                    writer.flush()

            if writer:
                writer.add_scalar('Loss/{}'.format(mode), total_loss, epoch)
                writer.flush()

            print("[{}] Epoch {}: Loss = {}".format(mode, epoch, total_loss))

    hparam_dict = {
        'train_split': args.train_metadata,
        'val_split': args.val_metadata,
        'lr': args.lr,
        'epochs': args.epochs,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'shuffle': args.shuffle,
        'lr_scheduler_gamma': args.lr_scheduler_gamma,
        'lr_scheduler_step_size': args.lr_scheduler_step_size,
    }
    metric_dict = {
        'AUC-OVR': best_auc_ovr,
        'AUC-OVO': best_auc_ovo,
        'ACC': best_acc
    }
    if writer:
        writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)
        writer.flush()

    Path(args.output_dir).mkdir(exist_ok=True)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, Path(args.output_dir) / '{}.pt'.format(args.name))
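
# Restore helper (sketch, not part of the original script): main() saves a checkpoint
# bundling the hyper-parameters, best validation metrics, and the model/optimizer/
# scheduler state dicts, so it can be reloaded roughly as below. The function name,
# the 'vocab_size' argument, and the checkpoint path are illustrative assumptions;
# the caller must rebuild the same Dictionary/Model configuration used for training.
def load_checkpoint(checkpoint_path, vocab_size, device='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = Model(vocab_size=vocab_size).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Return the saved hparams/metrics alongside the model for inspection
    return model, checkpoint['hparam_dict'], checkpoint['metric_dict']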
def main():
    parser = argparse.ArgumentParser(
        description='Train a neural machine translation model')

    # Training corpus
    corpora_group = parser.add_argument_group(
        'training corpora',
        'Corpora related arguments; specify either monolingual or parallel training corpora (or both)'
    )
    corpora_group.add_argument('--src_path',
                               help='the source language monolingual corpus')
    corpora_group.add_argument('--trg_path',
                               help='the target language monolingual corpus')
    corpora_group.add_argument(
        '--max_sentence_length',
        type=int,
        default=90,
        help='the maximum sentence length for training (defaults to 90)')

    # Embeddings/vocabulary
    embedding_group = parser.add_argument_group(
        'embeddings',
        'Embedding related arguments; either give pre-trained cross-lingual embeddings, or a vocabulary and embedding dimensionality to randomly initialize them'
    )
    embedding_group.add_argument('--src_vocabulary',
                                 help='the source language vocabulary')
    embedding_group.add_argument('--trg_vocabulary',
                                 help='the target language vocabulary')
    embedding_group.add_argument('--embedding_size',
                                 type=int,
                                 default=0,
                                 help='the word embedding size')

    # Architecture
    architecture_group = parser.add_argument_group(
        'architecture', 'Architecture related arguments')
    architecture_group.add_argument(
        '--layers',
        type=int,
        default=2,
        help='the number of encoder/decoder layers (defaults to 2)')
    architecture_group.add_argument(
        '--enc_hid_dim',
        type=int,
        default=512,
        help='the number of dimensions for the encoder hidden layer (defaults to 512)')
    architecture_group.add_argument(
        '--dec_hid_dim',
        type=int,
        default=512,
        help='the number of dimensions for the decoder hidden layer (defaults to 512)')

    # Optimization
    optimization_group = parser.add_argument_group(
        'optimization', 'Optimization related arguments')
    optimization_group.add_argument('--batch_size',
                                    type=int,
                                    default=128,
                                    help='the batch size (defaults to 128)')
    optimization_group.add_argument(
        '--learning_rate',
        type=float,
        default=0.0002,
        help='the global learning rate (defaults to 0.0002)')
    optimization_group.add_argument(
        '--dropout',
        metavar='PROB',
        type=float,
        default=0.4,
        help='dropout probability for the encoder/decoder (defaults to 0.4)')
    optimization_group.add_argument(
        '--param_init',
        metavar='RANGE',
        type=float,
        default=0.1,
        help='uniform initialization in the specified range (defaults to 0.1, 0 for module specific default initialization)'
    )
    optimization_group.add_argument(
        '--iterations',
        type=int,
        default=50,
        help='the number of training iterations (defaults to 50)')

    # Model saving
    saving_group = parser.add_argument_group(
        'model saving', 'Arguments for saving the trained model')
    saving_group.add_argument('--save_path',
                              metavar='PREFIX',
                              help='save models with the given prefix')
    saving_group.add_argument('--save_interval',
                              type=int,
                              default=0,
                              help='save intermediate models at this interval')
    saving_group.add_argument('--model_init_path', help='model init path')

    # Logging/validation
    logging_group = parser.add_argument_group(
        'logging', 'Logging and validation arguments')
    logging_group.add_argument('--log_interval',
                               type=int,
                               default=1000,
                               help='log at this interval (defaults to 1000)')
    logging_group.add_argument('--validate_batch_size',
                               type=int,
                               default=1,
                               help='the validation batch size (defaults to 1)')
    corpora_group.add_argument('--inference_output',
                               help='the output file for inference results')
    corpora_group.add_argument('--validation_src_path',
                               help='the source language validation corpus')
    corpora_group.add_argument('--validation_trg_path',
                               help='the target language validation corpus')

    # Other
    parser.add_argument(
        '--encoding',
        default='utf-8',
        help='the character encoding for input/output (defaults to utf-8)')
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='use cuda')
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--type",
                        type=str,
                        default='train',
                        help="type: train/inference/debug")

    args = parser.parse_args()
    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    src_dictionary = Dictionary(
        [word.strip() for word in open(args.src_vocabulary).readlines()])
    trg_dictionary = Dictionary(
        [word.strip() for word in open(args.trg_vocabulary).readlines()])

    def init_weights(m):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param.data, mean=0, std=0.01)
            else:
                nn.init.constant_(param.data, 0)

    if not args.model_init_path:
        attn = Attention(args.enc_hid_dim, args.dec_hid_dim)
        enc = Encoder(src_dictionary.size(), args.embedding_size,
                      args.enc_hid_dim, args.dec_hid_dim, args.dropout,
                      src_dictionary.PAD)
        dec = Decoder(trg_dictionary.size(), args.embedding_size,
                      args.enc_hid_dim, args.dec_hid_dim, args.dropout, attn)
        s2s = Seq2Seq(enc, dec, src_dictionary.PAD, device)
        parallel_model = Parser(src_dictionary, trg_dictionary, s2s, device)
        parallel_model.apply(init_weights)
    else:
        print(f"load init model from {args.model_init_path}")
        parallel_model = torch.load(args.model_init_path)

    parallel_model = parallel_model.to(device)

    if args.type == TEST:
        test_dataset = treeDataset(args.validation_src_path,
                                   args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size,
                                     collate_fn=collate_fn)
        hit, total, acc = evaluate_iter_loss2(parallel_model, test_dataloader,
                                              src_dictionary, trg_dictionary,
                                              device)
        print(f'hit: {hit: d} | total: {total: d} | acc: {acc: f}', flush=True)
    elif args.type == INFERENCE:
        test_dataset = customDataset(args.validation_src_path,
                                     args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size)
        hit, total, acc = evaluate_iter_acc(parallel_model, test_dataloader,
                                            src_dictionary, trg_dictionary,
                                            device, args.inference_output)
        print(f'hit: {hit: d} | total: {total: d} | acc: {acc: f}', flush=True)
    elif args.type == DEBUG:
        test_dataset = treeDataset(args.validation_src_path,
                                   args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size,
                                     collate_fn=collate_fn)
        hit, total, acc = debug_iter(parallel_model, test_dataloader,
                                     src_dictionary, trg_dictionary, device)
        print(f'hit: {hit: d} | total: {total: d} | acc: {acc: f}', flush=True)
    else:
        train_dataset = treeDataset(args.src_path, args.trg_path)
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=args.batch_size,
                                      collate_fn=collate_fn)
        test_dataset = treeDataset(args.validation_src_path,
                                   args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size,
                                     collate_fn=collate_fn)
        train(src_dictionary, trg_dictionary, train_dataloader,
              test_dataloader, parallel_model, device, args)
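
# Entry-point guard (sketch): main() parses its own arguments, so the module can be run
# directly. The example invocation below is illustrative only; the script name
# 'train_nmt.py' and all file paths are assumptions, not part of the original code.
#
#   python train_nmt.py --src_path train.src --trg_path train.trg \
#       --validation_src_path valid.src --validation_trg_path valid.trg \
#       --src_vocabulary vocab.src --trg_vocabulary vocab.trg \
#       --embedding_size 300 --save_path models/run1 --type train
if __name__ == '__main__':
    main()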