Example #1
import argparse
import json
import multiprocessing
from pathlib import Path

import numpy as np
import torch
import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.utils import tensorboard

# Project-specific helpers (Dictionary, the dataset classes, the collate
# functions, and Model) are assumed to be importable from the surrounding package.


def main(args: argparse.Namespace):
    # Load input data
    with open(args.train_metadata, 'r') as f:
        train_posts = json.load(f)

    with open(args.val_metadata, 'r') as f:
        val_posts = json.load(f)

    # Load labels
    labels = {}
    with open(args.label_intent, 'r') as f:
        intent_labels = json.load(f)
        labels['intent'] = {}
        for label in intent_labels:
            labels['intent'][label] = len(labels['intent'])

    with open(args.label_semiotic, 'r') as f:
        semiotic_labels = json.load(f)
        labels['semiotic'] = {}
        for label in semiotic_labels:
            labels['semiotic'][label] = len(labels['semiotic'])

    with open(args.label_contextual, 'r') as f:
        contextual_labels = json.load(f)
        labels['contextual'] = {}
        for label in contextual_labels:
            labels['contextual'][label] = len(labels['contextual'])

    # Build dictionary from training set
    train_captions = []
    for post in train_posts:
        train_captions.append(post['orig_caption'])
    dictionary = Dictionary(tokenizer_method="TreebankWordTokenizer")
    dictionary.build_dictionary_from_captions(train_captions)

    # Set up torch device
    if 'cuda' in args.device and torch.cuda.is_available():
        device = torch.device(args.device)
        kwargs = {'pin_memory': True}
    else:
        device = torch.device('cpu')
        kwargs = {}

    # Set up number of workers
    num_workers = min(multiprocessing.cpu_count(), args.num_workers)

    # Set up data loaders differently based on the task
    # TODO: Extend to ELMo + word2vec etc.
    if args.type == 'image_only':
        train_dataset = ImageOnlyDataset(train_posts, labels)
        val_dataset = ImageOnlyDataset(val_posts, labels)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=args.shuffle,
                                                        num_workers=num_workers,
                                                        collate_fn=collate_fn_pad_image_only,
                                                        **kwargs)
        val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                    batch_size=1,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn_pad_image_only,
                                                    **kwargs)
    elif args.type == 'image_text':
        train_dataset = ImageTextDataset(train_posts, labels, dictionary)
        val_dataset = ImageTextDataset(val_posts, labels, dictionary)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=args.shuffle,
                                                        num_workers=num_workers,
                                                        collate_fn=collate_fn_pad_image_text,
                                                        **kwargs)
        val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                    batch_size=1,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn_pad_image_text,
                                                    **kwargs)
    elif args.type == 'text_only':
        train_dataset = TextOnlyDataset(train_posts, labels, dictionary)
        val_dataset = TextOnlyDataset(val_posts, labels, dictionary)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=args.shuffle,
                                                        num_workers=num_workers,
                                                        collate_fn=collate_fn_pad_text_only,
                                                        **kwargs)
        val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                    batch_size=1,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn_pad_text_only,
                                                    **kwargs)
    else:
        raise ValueError("args.type '{}' is not supported.".format(args.type))

    # Set up the model
    model = Model(vocab_size=dictionary.size()).to(device)

    # Set up an optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_scheduler_step_size, gamma=args.lr_scheduler_gamma)  # decay the learning rate by gamma every step_size epochs

    # Set up loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Setup tensorboard
    if args.tensorboard:
        writer = tensorboard.SummaryWriter(log_dir=args.log_dir + "/" + args.name, flush_secs=1)
    else:
        writer = None

    # Training loop
    if args.classification == 'intent':
        keys = ['intent']
    elif args.classification == 'semiotic':
        keys = ['semiotic']
    elif args.classification == 'contextual':
        keys = ['contextual']
    elif args.classification == 'all':
        keys = ['intent', 'semiotic', 'contextual']
    else:
        raise ValueError("args.classification doesn't exist.")
    best_auc_ovr = 0.0
    best_auc_ovo = 0.0
    best_acc = 0.0
    best_model = None
    best_optimizer = None
    best_scheduler = None
    for epoch in range(args.epochs):
        for mode in ["train", "eval"]:
            # Set up a progress bar
            if mode == "train":
                pbar = tqdm.tqdm(enumerate(train_data_loader), total=len(train_data_loader))
                model.train()
            else:
                pbar = tqdm.tqdm(enumerate(val_data_loader), total=len(val_data_loader))
                model.eval()

            total_loss = 0
            label = dict.fromkeys(keys, np.array([], dtype=np.int64))
            pred = dict.fromkeys(keys, None)
            for _, batch in pbar:
                if 'caption' not in batch:
                    caption_data = None
                else:
                    caption_data = batch['caption'].to(device)
                if 'image' not in batch:
                    image_data = None
                else:
                    image_data = batch['image'].to(device)
                label_batch = {}
                for key in keys:
                    label_batch[key] = batch['label'][key].to(device)
                    
                if mode == "train":
                    model.zero_grad()

                pred_batch = model(image_data, caption_data)
                
                for key in keys:
                    label[key] = np.concatenate((label[key], batch['label'][key].cpu().numpy()))
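                    # Numerically stable softmax over the logits: subtract the
                    # per-row max before exponentiating, then normalize each row.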
                    x = pred_batch[key].detach().cpu().numpy()
                    x_max = np.max(x, axis=1).reshape(-1, 1)
                    z = np.exp(x - x_max)
                    prediction_scores = z / np.sum(z, axis=1).reshape(-1, 1)
                    if pred[key] is not None:
                        pred[key] = np.vstack((pred[key], prediction_scores))
                    else:
                        pred[key] = prediction_scores
                       
                loss_batch = {}
                loss = None
                for key in keys:
                    loss_batch[key] = loss_fn(pred_batch[key], label_batch[key])
                    if loss is None:
                        loss = loss_batch[key]
                    else:
                        loss += loss_batch[key]

                total_loss += loss.item()

                if mode == "train":
                    loss.backward()
                    optimizer.step()

            # Terminate the progress bar
            pbar.close()
            
            # Update lr scheduler
            if mode == "train":
                scheduler.step()

            for key in keys:
                auc_score_ovr = roc_auc_score(label[key], pred[key], multi_class='ovr') # pylint: disable-all
                auc_score_ovo = roc_auc_score(label[key], pred[key], multi_class='ovo') # pylint: disable-all
                accuracy = accuracy_score(label[key], np.argmax(pred[key], axis=1))
                print("[{} - {}] [AUC-OVR={:.3f}, AUC-OVO={:.3f}, ACC={:.3f}]".format(mode, key, auc_score_ovr, auc_score_ovo, accuracy))
                
                if mode == "eval":
                    best_auc_ovr = max(best_auc_ovr, auc_score_ovr)
                    best_auc_ovo = max(best_auc_ovo, auc_score_ovo)
                    best_acc = max(best_acc, accuracy)
                    best_model = model
                    best_optimizer = optimizer
                    best_scheduler = scheduler
                
                if writer:
                    writer.add_scalar('AUC-OVR/{}-{}'.format(mode, key), auc_score_ovr, epoch)
                    writer.add_scalar('AUC-OVO/{}-{}'.format(mode, key), auc_score_ovo, epoch)
                    writer.add_scalar('ACC/{}-{}'.format(mode, key), accuracy, epoch)
                    writer.flush()

            if writer:
                writer.add_scalar('Loss/{}'.format(mode), total_loss, epoch)
                writer.flush()

            print("[{}] Epoch {}: Loss = {}".format(mode, epoch, total_loss))

    hparam_dict = {
        'train_split': args.train_metadata,
        'val_split': args.val_metadata,
        'lr': args.lr,
        'epochs': args.epochs,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'shuffle': args.shuffle,
        'lr_scheduler_gamma': args.lr_scheduler_gamma,
        'lr_scheduler_step_size': args.lr_scheduler_step_size,
    }
    metric_dict = {
        'AUC-OVR': best_auc_ovr,
        'AUC-OVO': best_auc_ovo,
        'ACC': best_acc
    }

    if writer:
        writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)
        writer.flush()
    
    Path(args.output_dir).mkdir(exist_ok=True)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, Path(args.output_dir) / '{}.pt'.format(args.name))
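
A minimal sketch of the command-line parser this function appears to expect, inferred from the attributes read off `args` above. The flag names mirror those attributes, but the defaults and choices shown here are assumptions, not the original script's interface.

# Hypothetical argument parser matching the attributes used in main() above.
def build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description='Train an intent/semiotic/contextual classifier')
    parser.add_argument('--train_metadata', required=True, help='JSON file with training posts')
    parser.add_argument('--val_metadata', required=True, help='JSON file with validation posts')
    parser.add_argument('--label_intent', required=True, help='JSON list of intent labels')
    parser.add_argument('--label_semiotic', required=True, help='JSON list of semiotic labels')
    parser.add_argument('--label_contextual', required=True, help='JSON list of contextual labels')
    parser.add_argument('--type', choices=['image_only', 'image_text', 'text_only'], default='image_text')
    parser.add_argument('--classification', choices=['intent', 'semiotic', 'contextual', 'all'], default='all')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--batch_size', type=int, default=32)       # assumed default
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--num_workers', type=int, default=4)       # assumed default
    parser.add_argument('--epochs', type=int, default=30)           # assumed default
    parser.add_argument('--lr', type=float, default=1e-4)           # assumed default
    parser.add_argument('--lr_scheduler_step_size', type=int, default=15)
    parser.add_argument('--lr_scheduler_gamma', type=float, default=0.1)
    parser.add_argument('--tensorboard', action='store_true')
    parser.add_argument('--log_dir', default='runs')                # assumed default
    parser.add_argument('--name', default='experiment')             # assumed default
    parser.add_argument('--output_dir', default='checkpoints')      # assumed default
    return parser


if __name__ == '__main__':
    main(build_arg_parser().parse_args())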
Example #2
import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Project-specific components (Dictionary, Attention, Encoder, Decoder, Seq2Seq,
# Parser, the dataset classes, collate_fn, and the train/evaluate helpers) are
# assumed to be importable from the surrounding module.


def main():
    parser = argparse.ArgumentParser(
        description='Train a neural machine translation model')

    # Training corpus
    corpora_group = parser.add_argument_group(
        'training corpora',
        'Corpora related arguments; specify either monolingual or parallel training corpora (or both)'
    )
    corpora_group.add_argument('--src_path',
                               help='the source language monolingual corpus')
    corpora_group.add_argument('--trg_path',
                               help='the target language monolingual corpus')
    corpora_group.add_argument(
        '--max_sentence_length',
        type=int,
        default=90,
        help='the maximum sentence length for training (defaults to 90)')

    # Embeddings/vocabulary
    embedding_group = parser.add_argument_group(
        'embeddings',
        'Embedding related arguments; either give pre-trained cross-lingual embeddings, or a vocabulary and embedding dimensionality to randomly initialize them'
    )
    embedding_group.add_argument('--src_vocabulary',
                                 help='the source language vocabulary')
    embedding_group.add_argument('--trg_vocabulary',
                                 help='the target language vocabulary')
    embedding_group.add_argument('--embedding_size',
                                 type=int,
                                 default=0,
                                 help='the word embedding size')

    # Architecture
    architecture_group = parser.add_argument_group(
        'architecture', 'Architecture related arguments')
    architecture_group.add_argument(
        '--layers',
        type=int,
        default=2,
        help='the number of encoder/decoder layers (defaults to 2)')
    architecture_group.add_argument(
        '--enc_hid_dim',
        type=int,
        default=512,
        help='the number of dimensions for the encoder hidden state (defaults to 512)')
    architecture_group.add_argument(
        '--dec_hid_dim',
        type=int,
        default=512,
        help='the number of dimensions for the decoder hidden state (defaults to 512)')

    # Optimization
    optimization_group = parser.add_argument_group(
        'optimization', 'Optimization related arguments')
    optimization_group.add_argument('--batch_size',
                                    type=int,
                                    default=128,
                                    help='the batch size (defaults to 128)')
    optimization_group.add_argument(
        '--learning_rate',
        type=float,
        default=0.0002,
        help='the global learning rate (defaults to 0.0002)')
    optimization_group.add_argument(
        '--dropout',
        metavar='PROB',
        type=float,
        default=0.4,
        help='dropout probability for the encoder/decoder (defaults to 0.4)')
    optimization_group.add_argument(
        '--param_init',
        metavar='RANGE',
        type=float,
        default=0.1,
        help=
        'uniform initialization in the specified range (defaults to 0.1,  0 for module specific default initialization)'
    )
    optimization_group.add_argument(
        '--iterations',
        type=int,
        default=50,
        help='the number of training iterations (defaults to 50)')
    # Model saving
    saving_group = parser.add_argument_group(
        'model saving', 'Arguments for saving the trained model')
    saving_group.add_argument('--save_path',
                              metavar='PREFIX',
                              help='save models with the given prefix')
    saving_group.add_argument('--save_interval',
                              type=int,
                              default=0,
                              help='save intermediate models at this interval')
    saving_group.add_argument('--model_init_path', help='model init path')

    # Logging/validation
    logging_group = parser.add_argument_group(
        'logging', 'Logging and validation arguments')
    logging_group.add_argument('--log_interval',
                               type=int,
                               default=1000,
                               help='log at this interval (defaults to 1000)')
    logging_group.add_argument('--validate_batch_size',
                               type=int,
                               default=1,
                               help='the validation batch size (defaults to 1)')
    corpora_group.add_argument('--inference_output',
                               help='file to write inference output to')
    corpora_group.add_argument('--validation_src_path',
                               help='the source language validation corpus')
    corpora_group.add_argument('--validation_trg_path',
                               help='the target language validation corpus')

    # Other
    parser.add_argument(
        '--encoding',
        default='utf-8',
        help='the character encoding for input/output (defaults to utf-8)')
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='use cuda')
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--type",
                        type=str,
                        default='train',
                        help="type: train/inference/debug")

    args = parser.parse_args()
    print(args)
    torch.manual_seed(args.seed)
    device = torch.device(
        'cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    with open(args.src_vocabulary, encoding=args.encoding) as f:
        src_dictionary = Dictionary([word.strip() for word in f])
    with open(args.trg_vocabulary, encoding=args.encoding) as f:
        trg_dictionary = Dictionary([word.strip() for word in f])

    def init_weights(m):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param.data, mean=0, std=0.01)
            else:
                nn.init.constant_(param.data, 0)

    if not args.model_init_path:
        attn = Attention(args.enc_hid_dim, args.dec_hid_dim)
        enc = Encoder(src_dictionary.size(), args.embedding_size,
                      args.enc_hid_dim, args.dec_hid_dim, args.dropout,
                      src_dictionary.PAD)
        dec = Decoder(trg_dictionary.size(), args.embedding_size,
                      args.enc_hid_dim, args.dec_hid_dim, args.dropout, attn)
        s2s = Seq2Seq(enc, dec, src_dictionary.PAD, device)
        parallel_model = Parser(src_dictionary, trg_dictionary, s2s, device)
        parallel_model.apply(init_weights)

    else:
        print(f"load init model from {args.model_init_path}")
        parallel_model = torch.load(args.model_init_path)

    parallel_model = parallel_model.to(device)

    if args.type == TEST:
        test_dataset = treeDataset(args.validation_src_path,
                                   args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size,
                                     collate_fn=collate_fn)
        hit, total, acc = evaluate_iter_loss2(parallel_model, test_dataloader,
                                              src_dictionary, trg_dictionary,
                                              device)
        print(f'hit: {hit: d} |  total: {total: d} | acc: {acc: f}',
              flush=True)

    elif args.type == INFERENCE:
        test_dataset = customDataset(args.validation_src_path,
                                     args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size)
        hit, total, acc = evaluate_iter_acc(parallel_model, test_dataloader,
                                            src_dictionary, trg_dictionary,
                                            device, args.inference_output)
        print(f'hit: {hit: d} |  total: {total: d} | acc: {acc: f}',
              flush=True)
    elif args.type == DEBUG:
        test_dataset = treeDataset(args.validation_src_path,
                                   args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size,
                                     collate_fn=collate_fn)
        hit, total, acc = debug_iter(parallel_model, test_dataloader,
                                     src_dictionary, trg_dictionary, device)
        print(f'hit: {hit: d} |  total: {total: d} | acc: {acc: f}',
              flush=True)

    else:
        train_dataset = treeDataset(args.src_path, args.trg_path)
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=args.batch_size,
                                      collate_fn=collate_fn)
        test_dataset = treeDataset(args.validation_src_path,
                                   args.validation_trg_path)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=args.validate_batch_size,
                                     collate_fn=collate_fn)

        train(src_dictionary, trg_dictionary, train_dataloader,
              test_dataloader, parallel_model, device, args)
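
The dispatch on `args.type` above compares against module-level constants (TEST, INFERENCE, DEBUG) defined elsewhere in the original file. A minimal sketch of what they might look like, with assumed string values, plus the usual entry point:

# Assumed module-level constants for the --type dispatch above; the values are
# guesses (only 'train', 'inference', and 'debug' appear in the --type help text).
TEST = 'test'
INFERENCE = 'inference'
DEBUG = 'debug'


if __name__ == '__main__':
    main()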