Example #1
def train(sp):
    dataset = data_loader.load_processed_data(args)
    train_data = dataset['train']
    print('{} training examples loaded'.format(len(train_data)))
    dev_data = dataset['dev']
    print('{} dev examples loaded'.format(len(dev_data)))

    if args.xavier_initialization:
        ops.initialize_module(sp.mdl, 'xavier')
    else:
        raise NotImplementedError

    sp.schema_graphs = dataset['schema']
    if args.checkpoint_path is not None:
        sp.load_checkpoint(args.checkpoint_path)

    if args.test:
        # In test mode, also train on the dev examples
        train_data = train_data + dev_data

    sp.run_train(train_data, dev_data)
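
Both examples reset parameters with ops.initialize_module(model, 'xavier'). That helper lives in the source repository and is not shown here; the following is a minimal sketch of what a Xavier initializer over a module might look like in plain PyTorch (the function name, signature, and the zero-init of 1-D parameters are assumptions, not the repository's implementation):

import torch.nn as nn

def initialize_module(module, method='xavier'):
    # Hypothetical sketch: Xavier/Glorot init for every trainable weight matrix.
    for param in module.parameters():
        if not param.requires_grad:
            continue
        if param.dim() > 1:
            if method == 'xavier':
                nn.init.xavier_uniform_(param)
            else:
                raise NotImplementedError(method)
        else:
            # Biases and other 1-D parameters: zero-initialize.
            nn.init.zeros_(param)
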
Example #2
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')

    wandb.init(project='translatability-prediction', name=get_wandb_tag(args))
    wandb.watch(trans_checker)

    # Hyperparameters
    batch_size = 16
    num_peek_epochs = 1

    # Loss function
    # -100 is a dummy padding value since all output spans will be of length 2
    loss_fun = MaskedCrossEntropyLoss(-100)

    # Optimizer: a separate fine-tuning rate (args.bert_finetune_rate) for the
    # 'trans_parameters' group, args.learning_rate for everything else
    optimizer = optim.Adam([
        {'params': [p for n, p in trans_checker.named_parameters()
                    if 'trans_parameters' not in n and p.requires_grad]},
        {'params': [p for n, p in trans_checker.named_parameters()
                    if 'trans_parameters' in n and p.requires_grad],
         'lr': args.bert_finetune_rate}
    ], lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer, [args.warmup_init_lr, args.warmup_init_ft_lr],
        [args.num_warmup_steps, args.num_warmup_steps], args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()

        epoch_losses = []

        for i in tqdm(range(0, len(train_data), batch_size)):
            wandb.log({'learning_rate/{}'.format(args.dataset_name): optimizer.param_groups[0]['lr']})
            wandb.log({'fine_tuning_rate/{}'.format(args.dataset_name): optimizer.param_groups[1]['lr']})
            mini_batch = train_data[i:i + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                          bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            target_span_ids, _ = ops.pad_batch(
                [exp.span_ids for exp in mini_batch], bu.pad_id)
            output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_span_ids)
            loss.backward()
            epoch_losses.append(float(loss))

            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(),
                                         args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        # Report training loss and evaluate on the dev set every num_peek_epochs epochs
        if epoch_id % num_peek_epochs == 0:
            stdout_msg = 'Epoch {}: average training loss = {}'.format(
                epoch_id, np.mean(epoch_losses))
            print(stdout_msg)
            wandb.log({'cross_entropy_loss/{}'.format(args.dataset_name): np.mean(epoch_losses)})
            pred_spans = trans_checker.inference(dev_data)
            target_spans = [exp.span_ids for exp in dev_data]
            trans_acc = translatablity_eval(pred_spans, target_spans)
            print('Dev translatability accuracy = {}'.format(trans_acc))
            if trans_acc > best_dev_metrics:
                model_path = os.path.join(model_dir, 'model-best.tar')
                trans_checker.save_checkpoint(optimizer, lr_scheduler,
                                              model_path)
                best_dev_metrics = trans_acc

            span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
            print('Dev span accuracy = {}'.format(span_acc))
            print('Dev span precision = {}'.format(prec))
            print('Dev span recall = {}'.format(recall))
            print('Dev span F1 = {}'.format(f1))
            wandb.log({'translatability_accuracy/{}'.format(args.dataset_name): trans_acc})
            wandb.log({'span_accuracy/{}'.format(args.dataset_name): span_acc})
            wandb.log({'span_f1/{}'.format(args.dataset_name): f1})
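
The loss-function comment in Example #2 notes that -100 serves as a dummy padding value for the span targets. MaskedCrossEntropyLoss itself comes from the source repository; below is a minimal sketch of how such a loss could be built on nn.CrossEntropyLoss's ignore_index, assuming the model scores token positions for each span endpoint (the tensor shapes are assumptions):

import torch.nn as nn

class MaskedCrossEntropyLoss(nn.Module):
    # Hypothetical sketch: cross-entropy over span-endpoint positions that
    # ignores any target equal to pad_id (e.g. -100).
    def __init__(self, pad_id=-100):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, scores, targets):
        # scores: (batch, num_endpoints, num_positions); targets: (batch, num_endpoints)
        return self.ce(scores.reshape(-1, scores.size(-1)), targets.reshape(-1))

# Span targets padded with -100 contribute nothing to the loss.
loss_fun = MaskedCrossEntropyLoss(-100)
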
Example #3
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')

    # Hyperparameters
    batch_size = min(len(train_data), 12)
    num_peek_epochs = 1

    # Loss function
    loss_fun = nn.BCELoss()
    span_extract_pad_id = -100
    span_extract_loss_fun = MaskedCrossEntropyLoss(span_extract_pad_id)

    # Optimizer
    optimizer = optim.Adam([
        {'params': [p for n, p in trans_checker.named_parameters()
                    if 'trans_parameters' not in n and p.requires_grad]},
        {'params': [p for n, p in trans_checker.named_parameters()
                    if 'trans_parameters' in n and p.requires_grad],
         'lr': args.bert_finetune_rate}
    ], lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer, [args.warmup_init_lr, args.warmup_init_ft_lr], [args.num_warmup_steps, args.num_warmup_steps],
        args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()

        epoch_losses = []

        for i in tqdm(range(0, len(train_data), batch_size)):
            mini_batch = train_data[i: i + batch_size]
            _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch([exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            target_ids = ops.int_var_cuda([1 if exp.span_ids[0] == 0 else 0 for exp in mini_batch])
            target_span_ids, _ = ops.pad_batch([exp.span_ids for exp in mini_batch], bu.pad_id)
            target_span_ids = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
                              target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
            output, span_extract_output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_ids.unsqueeze(1).float())
            span_extract_loss = span_extract_loss_fun(span_extract_output, target_span_ids)
            loss += span_extract_loss
            loss.backward()
            epoch_losses.append(float(loss))

            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(), args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        with torch.no_grad():
            # Report training loss and evaluate on the dev set every num_peek_epochs epochs
            if epoch_id % num_peek_epochs == 0:
                stdout_msg = 'Epoch {}: average training loss = {}'.format(epoch_id, np.mean(epoch_losses))
                print(stdout_msg)
                pred_trans, pred_spans = trans_checker.inference(dev_data)
                targets = [1 if exp.span_ids[0] == 0 else 0 for exp in dev_data]
                target_spans = [exp.span_ids for exp in dev_data]
                trans_acc = translatablity_eval(pred_trans, targets)
                print('Dev translatability accuracy = {}'.format(trans_acc))
                if trans_acc > best_dev_metrics:
                    model_path = os.path.join(model_dir, 'model-best.tar')
                    trans_checker.save_checkpoint(optimizer, lr_scheduler, model_path)
                    best_dev_metrics = trans_acc
                span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
                print('Dev span accuracy = {}'.format(span_acc))
                print('Dev span precision = {}'.format(prec))
                print('Dev span recall = {}'.format(recall))
                print('Dev span F1 = {}'.format(f1))
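
lrs.LinearScheduler is likewise repository-specific: it warms each parameter group up over args.num_warmup_steps steps and then anneals the learning rate over args.num_steps. A rough per-group equivalent using torch.optim.lr_scheduler.LambdaLR, under the simplifying assumption that warm-up starts from zero rather than from args.warmup_init_lr / args.warmup_init_ft_lr:

from torch.optim.lr_scheduler import LambdaLR

def linear_warmup_then_decay(num_warmup_steps, num_total_steps):
    # Multiplier applied to the group's base learning rate at every step.
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear warm-up
        return max(0.0, (num_total_steps - step) /
                   max(1, num_total_steps - num_warmup_steps))  # linear decay
    return lr_lambda

# One schedule per parameter group (main parameters, fine-tuned 'trans_parameters').
lr_scheduler = LambdaLR(optimizer, lr_lambda=[
    linear_warmup_then_decay(args.num_warmup_steps, args.num_steps),
    linear_warmup_then_decay(args.num_warmup_steps, args.num_steps),
])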