Example #1
def train_model(
    train_source,
    train_target,
    dev_source,
    dev_target,
    experiment_directory,
    resume=False,
):
    # Prepare dataset
    train = Seq2SeqDataset.from_file(train_source, train_target)
    train.build_vocab(300, 6000)
    dev = Seq2SeqDataset.from_file(
        dev_source,
        dev_target,
        share_fields_from=train,
    )
    input_vocab = train.src_field.vocab
    output_vocab = train.tgt_field.vocab

    # Prepare loss
    weight = torch.ones(len(output_vocab))
    pad = output_vocab.stoi[train.tgt_field.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not resume:
        seq2seq, optimizer, scheduler = initialize_model(
            train, input_vocab, output_vocab)

    # Train
    trainer = SupervisedTrainer(
        loss=loss,
        batch_size=32,
        checkpoint_every=50,
        print_every=10,
        experiment_directory=experiment_directory,
    )
    start = time.perf_counter()
    try:
        seq2seq = trainer.train(
            seq2seq,
            train,
            n_epochs=10,
            dev_data=dev,
            optimizer=optimizer,
            teacher_forcing_ratio=0.5,
            resume=resume,
        )
    # Capture ^C
    except KeyboardInterrupt:
        pass
    end = time.perf_counter() - start
    logging.info('Training time: %.2fs', end)

    return seq2seq, input_vocab, output_vocab
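
Example #1 calls an initialize_model helper that is not shown. The sketch below is a hypothetical reconstruction, assuming the EncoderRNN/DecoderRNN/Seq2seq, Optimizer and StepLR APIs used elsewhere on this page; the hyperparameters are illustrative only, not the original helper.

def initialize_model(train, input_vocab, output_vocab,
                     hidden_size=128, bidirectional=True, max_len=50):
    encoder = EncoderRNN(len(input_vocab), max_len, hidden_size,
                         bidirectional=bidirectional, variable_lengths=True)
    decoder = DecoderRNN(len(output_vocab), max_len,
                         hidden_size * 2 if bidirectional else hidden_size,
                         dropout_p=0.2, use_attention=True,
                         bidirectional=bidirectional,
                         # assumes the target field exposes sos_id/eos_id as in the other examples
                         eos_id=train.tgt_field.eos_id, sos_id=train.tgt_field.sos_id)
    seq2seq = Seq2seq(encoder, decoder)
    if torch.cuda.is_available():
        seq2seq.cuda()
    for param in seq2seq.parameters():
        param.data.uniform_(-0.08, 0.08)
    optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
    scheduler = StepLR(optimizer.optimizer, 1)
    optimizer.set_scheduler(scheduler)
    return seq2seq, optimizer, scheduler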
Example #2
    def __init__(self,
                 data_path,
                 model_save_path,
                 model_load_path,
                 hidden_size=32,
                 max_vocab=4000,
                 device='cuda'):
        self.src = SourceField()
        self.tgt = TargetField()
        self.max_length = 90
        self.data_path = data_path
        self.model_save_path = model_save_path
        self.model_load_path = model_load_path

        def len_filter(example):
            return len(example.src) <= self.max_length and len(
                example.tgt) <= self.max_length

        self.trainset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'train'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.devset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'eval'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.src.build_vocab(self.trainset, max_size=max_vocab)
        self.tgt.build_vocab(self.trainset, max_size=max_vocab)
        weight = torch.ones(len(self.tgt.vocab))
        pad = self.tgt.vocab.stoi[self.tgt.pad_token]
        self.loss = Perplexity(weight, pad)
        self.loss.cuda()
        self.optimizer = None
        self.hidden_size = hidden_size
        self.bidirectional = True
        encoder = EncoderRNN(len(self.src.vocab),
                             self.max_length,
                             self.hidden_size,
                             bidirectional=self.bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(self.tgt.vocab),
                             self.max_length,
                             self.hidden_size *
                             2 if self.bidirectional else self.hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=self.bidirectional,
                             eos_id=self.tgt.eos_id,
                             sos_id=self.tgt.sos_id)
        self.device = device
        self.seq2seq = Seq2seq(encoder, decoder).cuda()
        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
Example #3
    def test_perplexity(self):
        nll = NLLLoss()
        ppl = Perplexity()
        nll.eval_batch(self.outputs, self.batch)
        ppl.eval_batch(self.outputs, self.batch)

        nll_loss = nll.get_loss()
        ppl_loss = ppl.get_loss()

        self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))

    def test_perplexity(self):
        nll = NLLLoss()
        ppl = Perplexity()
        for output, target in zip(self.outputs, self.targets):
            nll.eval_batch(output, target)
            ppl.eval_batch(output, target)

        nll_loss = nll.get_loss()
        ppl_loss = ppl.get_loss()

        self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
Example #5
def build_model(src, tgt, hidden_size, mini_batch_size, bidirectional, dropout,
                attention, init_value):
    EXPERIMENT.param("Hidden", hidden_size)
    EXPERIMENT.param("Bidirectional", bidirectional)
    EXPERIMENT.param("Dropout", dropout)
    EXPERIMENT.param("Attention", attention)
    EXPERIMENT.param("Mini-batch", mini_batch_size)
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    encoder = EncoderRNN(len(src.vocab),
                         MAX_LEN,
                         hidden_size,
                         rnn_cell="lstm",
                         bidirectional=bidirectional,
                         dropout_p=dropout,
                         variable_lengths=False)
    decoder = DecoderRNN(
        len(tgt.vocab),
        MAX_LEN,
        hidden_size,  # * 2 if bidirectional else hidden_size,
        rnn_cell="lstm",
        use_attention=attention,
        eos_id=tgt.eos_id,
        sos_id=tgt.sos_id)
    seq2seq = Seq2seq(encoder, decoder)
    using_cuda = False
    if torch.cuda.is_available():
        using_cuda = True
        encoder.cuda()
        decoder.cuda()
        seq2seq.cuda()
        loss.cuda()
    EXPERIMENT.param("CUDA", using_cuda)
    for param in seq2seq.parameters():
        param.data.uniform_(-init_value, init_value)

    trainer = SupervisedTrainer(loss=loss,
                                batch_size=mini_batch_size,
                                checkpoint_every=5000,
                                random_seed=42,
                                print_every=1000)
    return seq2seq, trainer
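
A possible call site for build_model above; src, tgt and the train/dev datasets are assumed to be prepared with torchtext as in the other examples, and the hyperparameter values are illustrative only.

seq2seq, trainer = build_model(src, tgt,
                               hidden_size=256,
                               mini_batch_size=64,
                               bidirectional=True,
                               dropout=0.2,
                               attention=True,
                               init_value=0.08)
seq2seq = trainer.train(seq2seq, train, num_epochs=10,
                        dev_data=dev, teacher_forcing_ratio=0.5)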
Example #6
def load_model_data_evaluator(expt_dir, model_name, data_path, batch_size=128):
    checkpoint_path = os.path.join(expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, model_name)
    checkpoint = Checkpoint.load(checkpoint_path)
    model = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    data_all, data_sml, data_med, data_lrg, fields_inp, src, tgt, src_adv, idx_field = load_data(data_path)

    src.vocab = input_vocab
    tgt.vocab = output_vocab
    src_adv.vocab = input_vocab

    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    evaluator = Evaluator(loss=loss, batch_size=batch_size)

    return model, data_all, data_sml, data_med, data_lrg, evaluator, fields_inp
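
A hedged usage sketch for the loader above, using the Evaluator.evaluate(model, data) call seen in Example #17; the experiment directory, checkpoint name and data path are placeholders.

model, data_all, data_sml, data_med, data_lrg, evaluator, fields_inp = \
    load_model_data_evaluator('experiment', 'best_checkpoint', 'data/test.tsv', batch_size=128)
dev_loss, accuracy = evaluator.evaluate(model, data_all)
print('loss: %.4f  accuracy: %.4f' % (dev_loss, accuracy))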
Example #7
    def test_perplexity(self):
        num_class = 5
        num_batch = 10
        batch_size = 5

        outputs = [F.softmax(Variable(torch.randn(batch_size, num_class)))
                   for _ in range(num_batch)]
        targets = [Variable(torch.LongTensor([random.randint(0, num_class - 1)
                                              for _ in range(batch_size)]))
                   for _ in range(num_batch)]

        nll = NLLLoss()
        ppl = Perplexity()
        for output, target in zip(outputs, targets):
            nll.eval_batch(output, target)
            ppl.eval_batch(output, target)

        nll_loss = nll.get_loss()
        ppl_loss = ppl.get_loss()

        self.assertAlmostEqual(ppl_loss, math.exp(nll_loss))
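
The identity these tests assert is perplexity = exp(mean negative log-likelihood). Below is a standalone sanity check of that relationship with plain PyTorch, independent of the NLLLoss/Perplexity wrappers used above.

import math

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)                                # 3 target tokens, vocabulary of 5
targets = torch.tensor([0, 2, 4])
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)   # mean negative log-likelihood
perplexity = math.exp(nll.item())                         # exp of the mean NLL
print(perplexity)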
Example #8
def train(args):
    train_ds, dev_ds, src_field, tgt_field = build_dataset(args)
    model = build_model(tgt_field, bidirectional=True)
    trainer = SupervisedTrainer(loss=Perplexity(),
                                batch_size=args.batch_size,
                                checkpoint_every=50,
                                expt_dir=args.expt_dir,
                                print_every=args.print_every)
    model = trainer.train(model=model,
                          data=train_ds,
                          num_epochs=args.num_epochs,
                          optimizer=None,
                          dev_data=dev_ds)
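
The train function above only reads a handful of attributes from args; a minimal argparse setup that would satisfy it is sketched below (flag names and defaults are assumptions, and build_dataset/build_model are defined elsewhere in that project).

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--print_every', type=int, default=100)
parser.add_argument('--expt_dir', default='./experiment')

if __name__ == '__main__':
    train(parser.parse_args())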
Example #9
def pretrain_generator(model, train, dev):
    # pre-train generator
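    # NOTE: tgt, gen, random_seed, expt_gen_dir and resume are not arguments of this
    # function; they are assumed to be module-level globals in the surrounding script.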
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    optimizer = Optimizer(torch.optim.Adam(gen.parameters()), max_grad_norm=5)
    scheduler = StepLR(optimizer.optimizer, 1)
    optimizer.set_scheduler(scheduler)

    supervised = SupervisedTrainer(loss=loss,
                                   batch_size=32,
                                   random_seed=random_seed,
                                   expt_dir=expt_gen_dir)
    supervised.train(model,
                     train,
                     num_epochs=20,
                     dev_data=dev,
                     optimizer=optimizer,
                     teacher_forcing_ratio=0,
                     resume=resume)
Example #10
                         dropout_p=0.2,
                         use_attention=True)
    seq2seq = Seq2seq(encoder, decoder)

    if opt.resume:
        print("resuming training")
        latest_checkpoint = Checkpoint.get_latest_checkpoint(opt.expt_dir)
        seq2seq.load(latest_checkpoint)
    else:
        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

    # Prepare loss
    weight = torch.ones(output_vocab.get_vocab_size())
    mask = output_vocab.MASK_token_id
    loss = Perplexity(weight, mask)

    if torch.cuda.is_available():
        seq2seq.cuda()
        loss.cuda()

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=32,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)
    t.train(seq2seq,
            dataset,
            num_epochs=4,
            dev_data=dev_set,
Example #11
def train():
    src = SourceField(sequential=True, tokenize=lambda x: jieba.lcut(x))
    tgt = TargetField(sequential=True, tokenize=lambda x: jieba.lcut(x))
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='csv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='csv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    src.build_vocab(train, max_size=50000)
    tgt.build_vocab(train, max_size=50000)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        hidden_size = 128
        bidirectional = True
        encoder = EncoderRNN(len(src.vocab),
                             max_len,
                             hidden_size,
                             bidirectional=bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab),
                             max_len,
                             hidden_size * 2 if bidirectional else hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=bidirectional,
                             eos_id=tgt.eos_id,
                             sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and pass to the trainer.
        #
        # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        # scheduler = StepLR(optimizer.optimizer, 1)
        # optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=32,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=6,
                      dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)
    predictor = Predictor(seq2seq, input_vocab, output_vocab)
Example #12
                    specials=replace_tokens)
    tgt.build_vocab(train, max_size=params['tgt_vocab_size'])
    # input_vocab = src.vocab
    # output_vocab = tgt.vocab

src_adv.vocab = src.vocab

logging.info('Indices of special replace tokens:\n')
for rep in replace_tokens:
    logging.info("%s, %d; " % (rep, src.vocab.stoi[rep]))
logging.info('\n')

# Prepare loss
weight = torch.ones(len(tgt.vocab))
pad = tgt.vocab.stoi[tgt.pad_token]
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    loss.cuda()

batch_adv_loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    batch_adv_loss.cuda()

# seq2seq = None
optimizer = None
if not opt.resume:
    # Initialize model
    hidden_size = params['hidden_size']
    bidirectional = True
    encoder = EncoderRNN(len(src.vocab),
                         max_len,
Example #13
    def test_perplexity_init(self):
        loss = Perplexity()
        self.assertEqual(loss.name, Perplexity._NAME)
Example #14
def main(option):
    random.seed(option.random_seed)
    torch.manual_seed(option.random_seed)

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level='INFO', stream=sys.stdout)

    glove = Glove(option.emb_file)
    logging.info('loaded embeddings from ' + option.emb_file)

    src_vocab = Vocab.build_from_glove(glove)
    tgt_vocab = Vocab.load(option.intent_vocab)

    train_dataset = load_intent_prediction_dataset(option.train_dataset,
                                                   src_vocab,
                                                   tgt_vocab,
                                                   device=option.device)
    dev_dataset = load_intent_prediction_dataset(option.dev_dataset,
                                                 src_vocab,
                                                 tgt_vocab,
                                                 device=option.device)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=option.batch_size,
                                   shuffle=True)
    dev_data_loader = DataLoader(dev_dataset,
                                 batch_size=len(dev_dataset),
                                 shuffle=False)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Prepare loss
    weight = torch.ones(tgt_vocab_size)
    pad = tgt_vocab.stoi[tgt_vocab.pad_token]
    loss = Perplexity(weight, pad)
    loss.criterion.to(option.device)

    # Initialize model
    encoder = NeuralTensorNetwork(nn.Embedding(src_vocab_size, option.emb_dim),
                                  option.em_k)
    decoder = DecoderRNN(tgt_vocab_size,
                         option.im_max_len,
                         option.im_hidden_size,
                         use_attention=False,
                         bidirectional=False,
                         eos_id=tgt_vocab.stoi[tgt_vocab.eos_token],
                         sos_id=tgt_vocab.stoi[tgt_vocab.bos_token])
    encoder.to(option.device)
    decoder.to(option.device)

    init_model(encoder)
    init_model(decoder)

    encoder.embeddings.weight.data.copy_(torch.from_numpy(glove.embd).float())

    optimizer_params = [{
        'params': encoder.parameters()
    }, {
        'params': decoder.parameters()
    }]
    optimizer = Optimizer(optim.Adam(optimizer_params, lr=option.lr),
                          max_grad_norm=5)
    trainer = NTNTrainer(loss,
                         print_every=option.report_every,
                         device=option.device)
    encoder, decoder = trainer.train(
        encoder,
        decoder,
        optimizer,
        train_data_loader,
        num_epochs=option.epochs,
        dev_data_loader=dev_data_loader,
        teacher_forcing_ratio=option.im_teacher_forcing_ratio)

    predictor = NTNPredictor(encoder, decoder, src_vocab, tgt_vocab,
                             option.device)
    samples = [
        ("PersonX", "eventually told", "___"),
        ("PersonX", "tells", "PersonY 's tale"),
        ("PersonX", "always played", " ___"),
        ("PersonX", "would teach", "PersonY"),
        ("PersonX", "gets", "a ride"),
    ]
    for sample in samples:
        subj, verb, obj = sample
        subj = subj.lower().split(' ')
        verb = verb.lower().split(' ')
        obj = obj.lower().split(' ')
        print(sample, predictor.predict(subj, verb, obj))
Example #15
def apply_gradient_attack(data, model, input_vocab, replace_tokens, field_name,
                          opt):
    def convert_to_onehot(inp, vocab_size):
        return torch.zeros(inp.size(0), inp.size(1), vocab_size,
                           device=device).scatter_(2, inp.unsqueeze(2), 1.)

    batch_iterator = torchtext.data.BucketIterator(
        dataset=data,
        batch_size=opt.batch_size,
        sort=True,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        device=device,
        repeat=False)
    batch_generator = batch_iterator.__iter__()

    weight = torch.ones(len(tgt.vocab)).half()
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    model.train()

    d = {}

    for batch in tqdm.tqdm(batch_generator, total=len(batch_iterator)):
        indices = getattr(batch, 'index')
        input_variables, input_lengths = getattr(batch, field_name)
        target_variables = getattr(batch, 'tgt')

        # Do random attack if inputs are too long and will OOM under gradient attack
        if max(getattr(batch, field_name)[1]) > 250:
            rand_replacements = get_random_token_replacement(
                input_variables.cpu().numpy(), input_vocab,
                indices.cpu().numpy(), replace_tokens, opt.distinct)

            d.update(rand_replacements)
            continue

        # convert input_variables to one_hot
        input_onehot = Variable(convert_to_onehot(input_variables,
                                                  vocab_size=len(input_vocab)),
                                requires_grad=True).half()

        # Forward propagation
        decoder_outputs, decoder_hidden, other = model(input_onehot,
                                                       input_lengths,
                                                       target_variables,
                                                       already_one_hot=True)

        # print outputs for debugging
        # for i,output_seq_len in enumerate(other['length']):
        #	print(i,output_seq_len)
        #	tgt_id_seq = [other['sequence'][di][i].data[0] for di in range(output_seq_len)]
        #	tgt_seq = [output_vocab.itos[tok] for tok in tgt_id_seq]
        #	print(' '.join([x for x in tgt_seq if x not in ['<sos>','<eos>','<pad>']]), end=', ')
        #	gt = [output_vocab.itos[tok] for tok in target_variables[i]]
        #	print(' '.join([x for x in gt if x not in ['<sos>','<eos>','<pad>']]))

        # Get loss
        loss.reset()
        for step, step_output in enumerate(decoder_outputs):
            batch_size = target_variables.size(0)
            loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                            target_variables[:, step + 1])
        # Backward propagation
        model.zero_grad()
        input_onehot.retain_grad()
        loss.backward(retain_graph=True)
        grads = input_onehot.grad
        del input_onehot

        best_replacements = get_best_token_replacement(
            input_variables.cpu().numpy(),
            grads.cpu().numpy(), input_vocab,
            indices.cpu().numpy(), replace_tokens, opt.distinct)

        d.update(best_replacements)

    return d
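
A hedged invocation of the attack above; the model, vocab and dataset are assumed to come from a loader like Example #6, the function also relies on a module-level tgt field for the loss, and the replacement-token names are placeholders.

replace_tokens = ['@R_%d@' % i for i in range(1, 6)]      # placeholder replacement tokens
site_map = apply_gradient_attack(data_all, model, input_vocab,
                                 replace_tokens, field_name='src_adv', opt=opt)
print('%d samples attacked' % len(site_map))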
Example #16
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-data', required=True)

    parser.add_argument('-epoch', type=int, default=3)
    parser.add_argument('-batch_size', type=int, default=64)

    parser.add_argument('-d_model', type=int, default=1024)
    parser.add_argument('-n_layer', type=int, default=1)

    parser.add_argument('-dropout', type=float, default=0)

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-teacher_forcing_ratio', type=float, default=0.5)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model
    opt.log = opt.save_model

    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.seed)

    #========= Loading Dataset =========#
    data = torch.load(opt.data)
    opt.max_token_seq_len = data['settings'].max_token_seq_len

    training_data, validation_data = prepare_dataloaders(data, opt)

    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    #========= Preparing Model =========#
    print(opt)
    device = torch.device('cuda' if opt.cuda else 'cpu')

    # model
    opt.bidirectional = True
    encoder = EncoderRNN(opt.src_vocab_size,
                         opt.max_token_seq_len,
                         opt.d_model,
                         bidirectional=opt.bidirectional,
                         variable_lengths=True)
    decoder = DecoderRNN(opt.tgt_vocab_size,
                         opt.max_token_seq_len,
                         opt.d_model * 2 if opt.bidirectional else opt.d_model,
                         n_layers=opt.n_layer,
                         dropout_p=opt.dropout,
                         use_attention=True,
                         bidirectional=opt.bidirectional,
                         eos_id=Constants.EOS,
                         sos_id=Constants.BOS)
    seq2seq = Seq2seq(encoder, decoder).to(device)
    for param in seq2seq.parameters():
        param.data.uniform_(-0.08, 0.08)

    seq2seq = nn.DataParallel(seq2seq)

    # loss
    weight = torch.ones(opt.tgt_vocab_size)
    pad = Constants.PAD
    loss = Perplexity(weight, pad)
    if opt.cuda:
        loss.cuda()

    # optimizer
    optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()),
                          max_grad_norm=5)

    train(seq2seq, training_data, validation_data, loss, optimizer, device,
          opt)
Example #17
def run_training(opt, default_data_dir, num_epochs=100):
    if opt.load_checkpoint is not None:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:

        # Prepare dataset
        src = SourceField()
        tgt = TargetField()
        max_len = 50

        data_file = os.path.join(default_data_dir, opt.train_path, 'data.txt')

        logging.info("Starting new Training session on %s", data_file)

        def len_filter(example):
            return (len(example.src) <= max_len) and (len(example.tgt) <= max_len) \
                   and (len(example.src) > 0) and (len(example.tgt) > 0)

        train = torchtext.data.TabularDataset(
            path=data_file, format='json',
            fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
            filter_pred=len_filter
        )

        dev = None
        if opt.no_dev is False:
            dev_data_file = os.path.join(default_data_dir, opt.train_path, 'dev-data.txt')
            dev = torchtext.data.TabularDataset(
                path=dev_data_file, format='json',
                fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
                filter_pred=len_filter
            )

        src.build_vocab(train, max_size=50000)
        tgt.build_vocab(train, max_size=50000)
        input_vocab = src.vocab
        output_vocab = tgt.vocab

        # NOTE: If the source field name and the target field name
        # are different from 'src' and 'tgt' respectively, they have
        # to be set explicitly before any training or inference
        # seq2seq.src_field_name = 'src'
        # seq2seq.tgt_field_name = 'tgt'

        # Prepare loss
        weight = torch.ones(len(tgt.vocab))
        pad = tgt.vocab.stoi[tgt.pad_token]
        loss = Perplexity(weight, pad)
        if torch.cuda.is_available():
            logging.info("Yayyy We got CUDA!!!")
            loss.cuda()
        else:
            logging.info("No cuda available device found running on cpu")

        seq2seq = None
        optimizer = None
        if not opt.resume:
            hidden_size = 128
            decoder_hidden_size = hidden_size * 2
            logging.info("EncoderRNN Hidden Size: %s", hidden_size)
            logging.info("DecoderRNN Hidden Size: %s", decoder_hidden_size)
            bidirectional = True
            encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 variable_lengths=True)
            decoder = DecoderRNN(len(tgt.vocab), max_len, decoder_hidden_size,
                                 dropout_p=0, use_attention=True,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 eos_id=tgt.eos_id, sos_id=tgt.sos_id)

            seq2seq = Seq2seq(encoder, decoder)
            if torch.cuda.is_available():
                seq2seq.cuda()

            for param in seq2seq.parameters():
                param.data.uniform_(-0.08, 0.08)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and pass to the trainer.

        optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        scheduler = StepLR(optimizer.optimizer, 1)
        optimizer.set_scheduler(scheduler)

        # train

        batch_size = 32
        checkpoint_every = num_epochs // 10
        print_every = num_epochs // 100

        properties = dict(batch_size=batch_size,
                          checkpoint_every=checkpoint_every,
                          print_every=print_every, expt_dir=opt.expt_dir,
                          num_epochs=num_epochs,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        logging.info("Starting training with the following Properties %s", json.dumps(properties, indent=2))
        t = SupervisedTrainer(loss=loss, batch_size=batch_size,
                              checkpoint_every=checkpoint_every,
                              print_every=print_every, expt_dir=opt.expt_dir)

        seq2seq = t.train(seq2seq, train,
                          num_epochs=num_epochs, dev_data=dev,
                          optimizer=optimizer,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        evaluator = Evaluator(loss=loss, batch_size=batch_size)

        if opt.no_dev is False:
            dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
            logging.info("Dev Loss: %s", dev_loss)
            logging.info("Accuracy: %s", dev_loss)

    beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 4))

    predictor = Predictor(beam_search, input_vocab, output_vocab)
    while True:
        try:
            seq_str = input("Type in a source sequence:")
            seq = seq_str.strip().split()
            results = predictor.predict_n(seq, n=3)
            for i, res in enumerate(results):
                print('option %s: %s\n' % (i + 1, res))
        except KeyboardInterrupt:
            logging.info("Bye Bye")
            exit(0)
Example #18

if __name__ == "__main__":
    # Prepare Datasets and Vocab
    src_vocab_list = VocabField.load_vocab(opt.src_vocab_file)
    tgt_vocab_list = VocabField.load_vocab(opt.tgt_vocab_file)
    src_vocab = VocabField(src_vocab_list, vocab_size=opt.src_vocab_size)
    tgt_vocab = VocabField(tgt_vocab_list,
                           vocab_size=opt.tgt_vocab_size,
                           sos_token="<SOS>",
                           eos_token="<EOS>")
    pad_id = tgt_vocab.word2idx[tgt_vocab.pad_token]

    # Prepare loss
    weight = torch.ones(len(tgt_vocab.vocab))
    loss = Perplexity(weight, pad_id)
    loss.to(device)

    # Initialize model
    encoder = EncoderRNN(len(src_vocab.vocab),
                         opt.max_src_length,
                         embedding_size=opt.embedding_size,
                         rnn_cell=opt.rnn_cell,
                         n_layers=opt.n_hidden_layer,
                         hidden_size=opt.hidden_size,
                         bidirectional=opt.bidirectional,
                         variable_lengths=False)

    decoder = DecoderRNN(len(tgt_vocab.vocab),
                         opt.max_tgt_length,
                         embedding_size=opt.embedding_size,
Example #19
class auto_seq2seq:
    def __init__(self,
                 data_path,
                 model_save_path,
                 model_load_path,
                 hidden_size=32,
                 max_vocab=4000,
                 device='cuda'):
        self.src = SourceField()
        self.tgt = TargetField()
        self.max_length = 90
        self.data_path = data_path
        self.model_save_path = model_save_path
        self.model_load_path = model_load_path

        def len_filter(example):
            return len(example.src) <= self.max_length and len(
                example.tgt) <= self.max_length

        self.trainset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'train'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.devset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'eval'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.src.build_vocab(self.trainset, max_size=max_vocab)
        self.tgt.build_vocab(self.trainset, max_size=max_vocab)
        weight = torch.ones(len(self.tgt.vocab))
        pad = self.tgt.vocab.stoi[self.tgt.pad_token]
        self.loss = Perplexity(weight, pad)
        self.loss.cuda()
        self.optimizer = None
        self.hidden_size = hidden_size
        self.bidirectional = True
        encoder = EncoderRNN(len(self.src.vocab),
                             self.max_length,
                             self.hidden_size,
                             bidirectional=self.bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(self.tgt.vocab),
                             self.max_length,
                             self.hidden_size *
                             2 if self.bidirectional else self.hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=self.bidirectional,
                             eos_id=self.tgt.eos_id,
                             sos_id=self.tgt.sos_id)
        self.device = device
        self.seq2seq = Seq2seq(encoder, decoder).cuda()
        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

    def train(self, epoch=20, resume=False):
        t = SupervisedTrainer(loss=self.loss,
                              batch_size=96,
                              checkpoint_every=1000,
                              print_every=1000,
                              expt_dir=self.model_save_path)
        self.seq2seq = t.train(self.seq2seq,
                               self.trainset,
                               num_epochs=epoch,
                               dev_data=self.devset,
                               optimizer=self.optimizer,
                               teacher_forcing_ratio=0.5,
                               resume=resume)
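
A minimal driver for the auto_seq2seq class above; the data_path directory is expected to contain 'train' and 'eval' TSV files as in the constructor, and all paths here are placeholders.

s2s = auto_seq2seq(data_path='data/',
                   model_save_path='checkpoints/',
                   model_load_path='checkpoints/',
                   hidden_size=32,
                   max_vocab=4000)
s2s.train(epoch=20, resume=False)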
Example #20
def offline_training(opt, target_file_path):

    # Prepare dataset with torchtext
    src = SourceField(tokenize=treebank_tokenizer)
    tgt = TargetField(tokenize=treebank_tokenizer)

    def sample_filter(sample):
        """ sample example for future purpose"""
        return True

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=sample_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=sample_filter)
    test = torchtext.data.TabularDataset(path=opt.dev_path,
                                         format='tsv',
                                         fields=[('src', src), ('tgt', tgt)],
                                         filter_pred=sample_filter)
    src.build_vocab(train, max_size=opt.src_vocab_size)
    tgt.build_vocab(train, max_size=opt.tgt_vocab_size)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    if opt.loss == 'perplexity':
        loss = Perplexity(weight, pad)
    else:
        raise TypeError

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        encoder = EncoderRNN(vocab_size=len(src.vocab),
                             max_len=opt.max_length,
                             hidden_size=opt.hidden_size,
                             input_dropout_p=opt.input_dropout_p,
                             dropout_p=opt.dropout_p,
                             n_layers=opt.n_layers,
                             bidirectional=opt.bidirectional,
                             rnn_cell=opt.rnn_cell,
                             variable_lengths=True,
                             embedding=input_vocab.vectors
                             if opt.use_pre_trained_embedding else None,
                             update_embedding=opt.update_embedding)
        decoder = DecoderRNN(vocab_size=len(tgt.vocab),
                             max_len=opt.max_length,
                             hidden_size=opt.hidden_size *
                             2 if opt.bidirectional else opt.hidden_size,
                             sos_id=tgt.sos_id,
                             eos_id=tgt.eos_id,
                             n_layers=opt.n_layers,
                             rnn_cell=opt.rnn_cell,
                             bidirectional=opt.bidirectional,
                             input_dropout_p=opt.input_dropout_p,
                             dropout_p=opt.dropout_p,
                             use_attention=opt.use_attention)
        seq2seq = Seq2seq(encoder=encoder, decoder=decoder)
        if opt.gpu >= 0 and torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
    # train
    trainer = SupervisedTrainer(loss=loss,
                                batch_size=opt.batch_size,
                                checkpoint_every=opt.checkpoint_every,
                                print_every=opt.print_every,
                                expt_dir=opt.expt_dir)
    seq2seq = trainer.train(model=seq2seq,
                            data=train,
                            num_epochs=opt.epochs,
                            resume=opt.resume,
                            dev_data=dev,
                            optimizer=optimizer,
                            teacher_forcing_ratio=opt.teacher_forcing_rate)
Example #21
    # trying to separate the feats and inputs
    # feats = [x for x in src.vocab.freqs if len(x) > 1]
    # example of getting multihot vector:
    # [1 if x in test_feats else 0 for x in feats]

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(DEVICE, weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        hidden_size = config['encoder embed']
        # TODO is this ideal?
        feat_hidden_size = len(feats.vocab) // 2
        bidirectional = True
        encoder = EncoderRNN(
            len(src.vocab),
            feats.vocab,
            max_len,
Example #22
    # prepare dataset
    train_cap_lang, train_label_lang, train_tuples, dev_cap_lang, dev_label_lang, dev_tuples, \
        x_mean_std, y_mean_std, w_mean_std, r_mean_std = prepare_data(opt.train_path,
            opt.dev_path, opt.mean_std_path, opt.max_len, opt.min_len, ixtoword, wordtoix)

    weight = torch.ones(len(train_label_lang.word2index))
    for word in train_label_lang.word2index:
        if train_label_lang.word2count[word] == 0:
            continue
        index = train_label_lang.word2index[word]
        weight[index] = weight[index] * opt.count_smooth / float(
            math.pow(train_label_lang.word2count[word], 0.8))

    # Prepare loss
    pad = train_label_lang.word2index["<pad>"]
    lloss = Perplexity(weight, pad, opt.lamda1)
    bloss = BBLoss(opt.batch_size, opt.gmm_comp_num, opt.lamda2)
    if torch.cuda.is_available():
        lloss.cuda()
        bloss.cuda()

    print('train_label_lang.index2word:')
    for index in train_label_lang.index2word:
        print('{} : {} '.format(index, train_label_lang.index2word[index]))

    print('train_label_lang.word2count:')
    for word in train_label_lang.word2count:
        print('{} : {} '.format(word, train_label_lang.word2count[word]))

    hidden_size = opt.embedding_dim
    encoder = PreEncoderRNN(train_cap_lang.n_words, nhidden=opt.embedding_dim)
Example #23
    output_vocab = tgt.vocab

    # inputs = torchtext.Field(lower=True, include_lengths=True, batch_first=True)
    # inputs.build_vocab(src.vocab)
    src.vocab.load_vectors(wv_type='glove.840B', wv_dim=300)

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        # hidden_size=128
        hidden_size = 300
        bidirectional = True

        encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                             bidirectional=bidirectional, variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,
                             dropout_p=0.2, use_attention=True, bidirectional=bidirectional,
Example #24
def apply_gradient_attack_v3(data,
                             model,
                             input_vocab,
                             replace_tokens,
                             field_name,
                             opt,
                             orig_tok_map,
                             idx_to_fname,
                             output_vocab=None,
                             device='cpu'):
    ########################################
    # Parameters that ideally need to come in from opt

    pgd_epochs = opt.u_pgd_epochs

    z_optim = opt.z_optim
    z_epsilon = int(opt.z_epsilon)
    z_init = opt.z_init  # 0: initialize with all zeros; 1: initialize with uniform; 2: debug
    z_learning_rate = opt.z_learning_rate

    u_optim = opt.u_optim
    u_learning_rate = opt.u_learning_rate

    li_use_loss_smoothing = [opt.use_loss_smoothing]
    smoothing_param = opt.smoothing_param

    evaluate_only_on_good_samples = False
    matches_json = '/mnt/outputs/exact_matches_idxs.json'

    vocab_to_use = opt.vocab_to_use
    ##########################################
    u_rand_update_pgd = False  # Optimal site is randomly selected instead of argmax
    u_projection = 2  # 1: simple 0, 1 projection; 2: simplex projection

    li_u_optim_technique = [
        1
    ]  # 1: PGD: SGD with relaxation; 2: signed gradient
    li_u_init_pgd = [
        3
    ]  # list(range(5)) # 0: original (fixed) init; 1: randomly initialize all tokens; 2: pick PGD optimal randomly instead of argmax; >2: randomly initialize only z=True;
    li_use_u_discrete = [True]
    smooth_iters = 10

    use_cw_loss = False
    choose_best_loss_among_iters = True

    analyze_exact_match_sample = False
    samples_to_analyze = 1
    zlen_debug = 4
    plt_fname = '/mnt/outputs/loss_batch.pkl'
    outpth = '/mnt/outputs/'

    stats = {}
    config_dict = OrderedDict([
        ('version', 'v3'),
        ('pgd_epochs', pgd_epochs),
        ('z_optim', z_optim),
        ('z_epsilon', z_epsilon),
        ('z_init', z_init),
        ('z_learning_rate', z_learning_rate),
        ('evaluate_only_on_good_samples', evaluate_only_on_good_samples),
        ('u_optim', u_optim),
        ('u_learning_rate', u_learning_rate),
        ('u_rand_update_pgd', u_rand_update_pgd),
        ('smooth_iters', smooth_iters),
        ('use_cw_loss', use_cw_loss),
        ('choose_best_loss_among_iters', choose_best_loss_among_iters),
        ('analyze_exact_match_sample', analyze_exact_match_sample),
    ])
    stats['config_dict'] = config_dict
    ########################################

    # This data structure is meant to return best replacements only for *one* set of best params.
    # If using in experiment mode (i.e. itertools.product has multiple combinations), don't expect
    # consistent results from best_replacements_dataset.
    best_replacements_dataset = {}
    '''
	with open(matches_json, 'r') as f:
		exact_matches_file_names = json.load(f) # mapping of file/sample index to file name
	exact_matches_file_names = set([str(e) for e in exact_matches_file_names])
	'''

    for params in itertools.product(li_u_optim_technique, li_u_init_pgd,
                                    li_use_loss_smoothing, li_use_u_discrete):
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(config_dict)
        (u_optim_technique, u_init_pgd, use_loss_smoothing,
         use_u_discrete) = params
        od = OrderedDict([
            ('u_optim_technique', u_optim_technique),
            ('u_init_pgd', u_init_pgd),
            ('use_loss_smoothing', use_loss_smoothing),
            ('use_u_discrete', use_u_discrete),
        ])
        pp.pprint(od)
        stats['config_dict2'] = od
        batch_iterator = torchtext.data.BucketIterator(
            dataset=data,
            batch_size=opt.batch_size,
            sort=True,
            sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            device=device,
            repeat=False)
        batch_generator = batch_iterator.__iter__()
        if use_cw_loss:
            loss_obj = AttackLoss(device=device)
        else:
            weight = torch.ones(len(output_vocab.vocab)).half()
            pad = output_vocab.vocab.stoi[output_vocab.pad_token]
            loss_obj = Perplexity(weight, pad)
            if torch.cuda.is_available():
                loss_obj.cuda()
        model.train()

        best_loss_among_iters, best_loss_among_iters_status = {}, {}
        nothing_to_attack, rand_replacement_too_long, tot_attacks, tot_samples = 0, 0, 0, 0
        sample_to_select_idx, pred_to_select, sample_to_select_idx_cnt, sname = None, None, 0, ''

        # a mask of length len(input_vocab) which lists which are valid/invalid tokens
        if vocab_to_use == 1:
            invalid_tokens_mask = get_valid_token_mask(negation=True,
                                                       vocab=input_vocab,
                                                       exclude=[])
        elif vocab_to_use == 2:
            invalid_tokens_mask = [False] * len(input_vocab)

        for bid, batch in enumerate(
                tqdm.tqdm(batch_generator, total=len(batch_iterator))):
            if analyze_exact_match_sample and (sample_to_select_idx_cnt >=
                                               samples_to_analyze):
                continue

            found_sample, zlen, plen, zstr = False, 0, 0, None
            indices = getattr(batch, 'index')
            input_variables, input_lengths = getattr(batch, field_name)
            target_variables = getattr(batch, 'tgt')
            orig_input_variables, orig_lens = getattr(batch, 'src')
            tot_samples += len(getattr(batch, field_name)[1])

            # Do random attack if inputs are too long and will OOM under gradient attack
            if max(getattr(batch, field_name)[1]) > 250:
                rand_replacement_too_long += len(getattr(batch, field_name)[1])
                rand_replacements = get_random_token_replacement_2(
                    input_variables.cpu().numpy(), input_vocab,
                    indices.cpu().numpy(), replace_tokens, opt.distinct,
                    z_epsilon)

                best_replacements_dataset.update(rand_replacements)
                continue

            # to update replacement variables with max-idx in case this is the iter with the best optimized loss
            update_this_iter = False

            indices = indices.cpu().numpy()
            inputs_oho = Variable(convert_to_onehot(
                input_variables, vocab_size=len(input_vocab), device=device),
                                  requires_grad=True).half()

            #### To compute which samples have exact matches with ground truth in this batch
            if analyze_exact_match_sample or evaluate_only_on_good_samples:
                # decoder_outputs: List[(max_length x decoded_output_sz)]; List length -- batch_sz
                # These steps are common for every batch.
                decoder_outputs, decoder_hidden, other = model(
                    inputs_oho,
                    input_lengths,
                    target_variables,
                    already_one_hot=True)

                output_seqs, ground_truths = [], []

                for i, output_seq_len in enumerate(other['length']):
                    # print(i,output_seq_len)
                    tgt_id_seq = [
                        other['sequence'][di][i].data[0]
                        for di in range(output_seq_len)
                    ]
                    tgt_seq = [
                        output_vocab.vocab.itos[tok] for tok in tgt_id_seq
                    ]
                    output_seqs.append(' '.join([
                        x for x in tgt_seq
                        if x not in ['<sos>', '<eos>', '<pad>']
                    ]))
                    gt = [
                        output_vocab.vocab.itos[tok]
                        for tok in target_variables[i]
                    ]
                    ground_truths.append(' '.join([
                        x for x in gt if x not in ['<sos>', '<eos>', '<pad>']
                    ]))

                other_metrics = calculate_metrics(output_seqs, ground_truths)

                if len(other_metrics['exact_match_idx']) > 0:
                    sample_to_select_idx = other_metrics['exact_match_idx'][0]

                if evaluate_only_on_good_samples:
                    pass
                    if len(other_metrics['good_match_idx']) == 0:
                        continue
                    attack_sample_set = other_metrics['good_match_idx']
                elif sample_to_select_idx is None:
                    continue

            ###############################################
            # Initialize z for the batch
            status_map, z_map, z_all_map, z_np_map, site_map_map, site_map_lookup_map, z_initialized_map,invalid_tokens_mask_map = {}, {}, {}, {}, {}, {}, {}, {}

            for ii in range(inputs_oho.shape[0]):
                replace_map_i, site_map, status = get_all_replacement_toks(
                    input_variables.cpu().numpy()[ii], None, input_vocab,
                    replace_tokens)

                if not status:
                    continue

                site_map_lookup = []
                for cnt, k in enumerate(site_map.keys()):
                    site_map_lookup.append(k)

                if z_init == 0:
                    z_np = np.zeros(len(site_map_lookup)).astype(float)
                elif z_init == 1:
                    z_np = (1 / len(site_map_lookup)) * np.ones(
                        len(site_map_lookup)).astype(float)
                elif z_init == 2:
                    z_np = np.zeros(len(site_map_lookup)).astype(float)
                    z_np[0] = 1

                z = torch.tensor(z_np, requires_grad=True, device=device)
                if len(z.shape) == 1:
                    z = z.unsqueeze(dim=1)

                mask = np.array(input_variables.cpu().numpy()[ii] *
                                [False]).astype(bool)
                for kk in range(len(site_map_lookup)):
                    if not z[kk]:
                        continue
                    m = site_map[site_map_lookup[kk]]
                    mask = np.array(m) | mask

                status_map[ii] = status
                z_map[ii] = z
                z_np_map[ii] = z_np
                z_all_map[ii] = list(mask)
                site_map_map[ii] = site_map
                site_map_lookup_map[ii] = site_map_lookup
                z_initialized_map[ii] = [False] * z_np.shape[0]
                # selected_toks = torch.sum(z * embed, dim=0)  # Element-wise mult

            if analyze_exact_match_sample and (
                    sample_to_select_idx not in z_np_map
                    or len(z_np_map[sample_to_select_idx]) < zlen_debug):
                continue

            new_inputs, site_map_map, z_all_map, input_lengths, sites_to_fix_map = replace_toks_batch(
                input_variables.cpu().numpy(), indices, z_map, site_map_map,
                site_map_lookup_map, {}, field_name, input_vocab, orig_tok_map,
                idx_to_fname)
            input_lengths = torch.tensor(input_lengths, device=device)
            inputs_oho_orig = Variable(convert_to_onehot(
                torch.tensor(new_inputs, device=device),
                vocab_size=len(input_vocab),
                device=device),
                                       requires_grad=True).half()
            inputs_oho_orig = modify_onehot(inputs_oho_orig, site_map_map,
                                            sites_to_fix_map, device)

            # Initialize input_hot_grad
            # This gets updated for each i with (not z_all_map) tokens being switched to x_orig
            if u_init_pgd == 1:
                input_h = inputs_oho_orig[0][0].clone().detach()
            elif u_init_pgd == 2:
                input_h = torch.zeros(inputs_oho_orig[0][0].shape).half()
            elif u_init_pgd == 3:
                valid_tokens = [not t for t in invalid_tokens_mask[:]]
                input_h = inputs_oho_orig[0][0].clone().detach()
                input_h[valid_tokens] = 1 / sum(valid_tokens)
                input_h[invalid_tokens_mask] = 0
            elif u_init_pgd == 4:
                input_h = (1 - inputs_oho_orig[0][0].clone()) / (
                    len(invalid_tokens_mask) - 1)
            input_hot_grad = input_h.clone().detach().requires_grad_(
                True).repeat(inputs_oho_orig.shape[0],
                             inputs_oho_orig.shape[1]).view(
                                 inputs_oho_orig.shape)

            ##################################################
            for i in range(inputs_oho_orig.shape[0]):
                if i not in status_map:
                    continue

                if analyze_exact_match_sample and (i != sample_to_select_idx):
                    continue

                fn_name = str(indices[i])

                input_hot_orig_i = inputs_oho_orig[i].unsqueeze(
                    0
                )  # is not affected by gradients; okay to copy by reference
                input_hot_grad_i = input_hot_grad[i].unsqueeze(0)
                il_i = input_lengths[i].unsqueeze(0)
                tv_i = target_variables[i].unsqueeze(0)
                site_map_lookup = site_map_lookup_map[i]
                z = z_map[i]
                site_map = site_map_map[i]
                z_all = z_all_map[i]

                if z_epsilon == 0:
                    z_epsilon = z.shape[0]

                if i not in status_map:
                    nothing_to_attack += 1
                    continue

                tot_attacks += 1

                if analyze_exact_match_sample:
                    sample_to_select_idx_cnt += 1
                    sname = fn_name
                    found_sample = True
                    print('found {}; z len {}'.format(sname, len(z_np_map[i])))
                    print([input_vocab.itos[t] for t in new_inputs[i]])
                    print([input_vocab.itos[t] for t in input_variables[i]])
                    zlen = sum(z_all_map[i])
                    plen = len(z_all_map[i])
                    zstr = str(z_np_map[i])
                    print(zstr)

                # Revert all (not z_mask) tokens to x_orig
                # Take care with cloning to ensure gradients are not shared.
                not_z_all = [not t for t in z_all]
                input_hot_grad_i[0][not_z_all] = input_hot_orig_i[0][
                    not_z_all].detach().clone().requires_grad_(True)

                embed = None
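                # Build a (num_sites x seq_len) 0/1 matrix: row k marks the token
                # positions occupied by the k-th replaceable site (assuming site_map[sm]
                # is a boolean mask over the input positions).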
                for sm in site_map_lookup:
                    if embed is None:
                        embed = np.array(site_map[sm]).astype(float)
                    else:
                        embed = np.vstack(
                            (embed, np.array(site_map[sm]).astype(float)))
                embed = torch.tensor(
                    embed, requires_grad=True,
                    device=device)  # values don't get updated/modified
                if len(embed.shape) == 1:
                    embed = embed.unsqueeze(dim=0)

                batch_loss_list_per_iter, best_replacements_sample = [], {}

                # Begin optim iters
                for j in range(pgd_epochs):
                    # Forward propagation
                    # decoder_outputs: List[(max_length x decoded_output_sz)]; List length -- batch_sz
                    # Soft per-position selection: weight each site mask by its z value.
                    selected_toks = torch.sum(z * embed, dim=0)
                    selected_toks = selected_toks.repeat(
                        input_hot_grad_i.shape[2], 1).T.unsqueeze(0).half()
                    perturbed_sample = selected_toks * input_hot_grad_i + (
                        1 - selected_toks) * input_hot_orig_i

                    # Score the current perturbation. With use_u_discrete, the model is
                    # fed the hard arg-max one-hot of the perturbed sample; this pass only
                    # tracks the loss, gradients are computed in the pass further below.
                    if use_u_discrete:
                        a = perturbed_sample.argmax(2)
                        m = torch.zeros(perturbed_sample.shape,
                                        requires_grad=True,
                                        device=device).scatter(
                                            2, a.unsqueeze(2), 1.0).half()
                        decoder_outputs, decoder_hidden, other = model(
                            m, il_i, tv_i, already_one_hot=True)
                    else:
                        decoder_outputs, decoder_hidden, other = model(
                            perturbed_sample, il_i, tv_i, already_one_hot=True)
                    loss, l_scalar, sample_wise_loss_per_batch = calculate_loss(
                        use_cw_loss, loss_obj, decoder_outputs, other, tv_i)

                    batch_loss_list_per_iter.append(sample_wise_loss_per_batch)

                    # Track the best (highest) loss seen for this sample across PGD
                    # iterations; replacements are only recorded when the current
                    # iteration improves on it (see choose_best_loss_among_iters below).
                    if (fn_name not in best_loss_among_iters) or (
                            best_loss_among_iters[fn_name] <
                            sample_wise_loss_per_batch[0]):
                        best_loss_among_iters_status[fn_name] = True
                        best_loss_among_iters[
                            fn_name] = sample_wise_loss_per_batch[0]
                    else:
                        best_loss_among_iters_status[fn_name] = False

                    # Per-iteration copy of the invalid-token mask: tokens chosen for one
                    # site below are marked invalid for the remaining sites of this sample.
                    invalid_tokens_mask_ij = invalid_tokens_mask[:]

                    # Gradient pass: run the model on the continuous perturbed one-hot
                    # vectors and backpropagate the loss (the evaluation pass above is
                    # not used for gradients).
                    if not use_loss_smoothing:
                        decoder_outputs, decoder_hidden, other = model(
                            perturbed_sample, il_i, tv_i, already_one_hot=True)
                        loss, l_scalar, sample_wise_loss_per_batch = calculate_loss(
                            use_cw_loss, loss_obj, decoder_outputs, other,
                            tv_i)

                        # update loss and backprop
                        model.zero_grad()
                        input_hot_grad_i.retain_grad()
                        z.retain_grad()
                        loss.backward(retain_graph=True)

                        grads_oh_i = input_hot_grad_i.grad
                        gradients = grads_oh_i.detach().cpu().numpy()[0]
                        grads_z_i = z.grad
                    else:
                        b_loss, smooth_grads_oh, smooth_grads_z = [], None, None
                        mask_optimisee = torch.sum(
                            z * embed,
                            dim=0).cpu().detach().numpy().astype(bool)
                        for si in range(smooth_iters):
                            smooth_hot_grad_i = input_hot_grad_i.clone()
                            noise = smoothing_param * torch.empty(
                                input_hot_grad_i.shape, device=device).normal_(
                                    mean=0, std=1).half()
                            # Add noise only at the positions selected for optimization.
                            smooth_hot_grad_i[:, mask_optimisee, :] = (
                                smooth_hot_grad_i[:, mask_optimisee, :] +
                                noise[:, mask_optimisee, :])
                            # Note: the next assignment supersedes the masked update above
                            # and applies the noise at every position.
                            smooth_hot_grad_i = input_hot_grad_i + noise
                            smooth_input = selected_toks * smooth_hot_grad_i + (
                                1 - selected_toks) * input_hot_orig_i
                            smooth_decoder_outputs, smooth_decoder_hidden, smooth_other = model(
                                smooth_input, il_i, tv_i, already_one_hot=True)
                            loss, l_scalar, sample_wise_loss_per_batch = calculate_loss(
                                use_cw_loss, loss_obj, smooth_decoder_outputs,
                                smooth_other, tv_i)

                            # update loss and backprop
                            model.zero_grad()
                            smooth_hot_grad_i.retain_grad()
                            z.retain_grad()
                            loss.backward(retain_graph=True)

                            if smooth_grads_oh is None:
                                smooth_grads_oh = smooth_hot_grad_i.grad
                                smooth_grads_z = z.grad
                            else:
                                smooth_grads_oh += smooth_hot_grad_i.grad
                                smooth_grads_z += z.grad

                        grads_oh_i = smooth_grads_oh / smooth_iters
                        gradients = grads_oh_i.detach().cpu().numpy()[0]
                        grads_z_i = smooth_grads_z / smooth_iters

                    # Optimize input_hot_grad_i
                    if u_optim:
                        if analyze_exact_match_sample:
                            print('-- u optim --')
                        for idx in range(z.shape[0]):
                            # if z_np[idx] == 0:
                            #	continue
                            mask = site_map[site_map_lookup[idx]]
                            # Average the one-hot gradients over every occurrence of this
                            # site's token in the input sequence.
                            avg_tok_grads = np.mean(gradients[mask], axis=0)
                            repl_tok_idx = site_map_lookup[idx]
                            # print(repl_tok_idx)
                            repl_tok = input_vocab.itos[repl_tok_idx]
                            # print("repl tok: {}".format(repl_tok))
                            nabla = avg_tok_grads

                            if u_optim_technique == 2:
                                nabla = np.sign(nabla)

                            # PGD
                            step = u_learning_rate / np.sqrt(j + 1) * nabla
                            if use_cw_loss:
                                step = -1 * step

                            # The masked positions are initialized identically, so take
                            # the first row as the current replacement distribution.
                            input_h = input_hot_grad_i[0][mask, :][
                                0, :].detach().cpu().numpy()
                            # Debug:
                            # print("z idx {}".format(idx))
                            # print(np.expand_dims(input_h, axis=0).shape)
                            # print(np.argmax(np.expand_dims(input_h, axis=0), axis=1))
                            input_h = input_h + step

                            # projection
                            if u_projection == 1:
                                optim_input = np.clip(input_h, 0, 1)
                            elif u_projection == 2:
                                # simplex projection
                                fmu = lambda mu, a=input_h: np.sum(
                                    np.maximum(0, a - mu)) - 1
                                mu_opt = bisection(fmu, -1, 1, 30)
                                if mu_opt is None:
                                    mu_opt = 0  # fall back to 0 if bisection finds no root
                                optim_input = np.maximum(0, input_h - mu_opt)
                                # print(fmu(mu_opt))

                            # projection onto only valid tokens. Rest are set to 0
                            optim_input[invalid_tokens_mask_ij] = 0
                            # print(sum(invalid_tokens_mask_map))

                            if u_rand_update_pgd:
                                max_idx = random.randrange(
                                    optim_input.shape[0])
                            else:
                                max_idx = np.argmax(optim_input)

                            # Update to replacements with best loss so far
                            if choose_best_loss_among_iters:
                                if best_loss_among_iters_status[fn_name]:
                                    best_replacements_sample[
                                        repl_tok] = input_vocab.itos[max_idx]
                            else:
                                best_replacements_sample[
                                    repl_tok] = input_vocab.itos[max_idx]

                            # Mark this token as used so the remaining sites of this
                            # sample cannot pick the same replacement.
                            invalid_tokens_mask_ij[max_idx] = True

                            # Update optim_input
                            input_hot_grad_i[0][mask, :] = torch.tensor(
                                optim_input, requires_grad=True, device=device)

                        if analyze_exact_match_sample:
                            print('Best loss: ',
                                  best_loss_among_iters[fn_name])
                            print("Loss: {}".format(batch_loss_list_per_iter))
                            print(best_replacements_sample)

                    # Optimize z
                    if z_optim:
                        # print('Optimizing z')
                        if analyze_exact_match_sample:
                            print('-- z optim --')
                            print(z.squeeze().cpu().detach().numpy())
                            print("Constraint: {}".format(z_epsilon))

                        # Gradient ascent. Maximize CE loss
                        a = z + z_learning_rate / np.sqrt(j + 1) * grads_z_i
                        if analyze_exact_match_sample:
                            print(a.squeeze().cpu().detach().numpy())
                        a_np = a.cpu().detach().numpy()
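                        # Project back onto the budget set {0 <= z <= 1, sum(z) <= z_epsilon}:
                        # bisect on the dual variable mu, subtract it only when the budget
                        # constraint is active (mu_opt > 0), then clamp to [0, 1].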
                        fmu = lambda mu, a=a_np, epsilon=z_epsilon: np.sum(
                            a - mu) - epsilon
                        mu_opt = bisection(fmu, 0, np.max(a_np), 50)
                        if mu_opt is None:
                            mu_opt = 0  # fall back to 0 if bisection finds no root
                        if mu_opt > 0:
                            z = torch.clamp(a - mu_opt, 0, 1)
                        else:
                            z = torch.clamp(a, 0, 1)
                        # one = torch.ones(z.shape, device=device, requires_grad=True)
                        # zero = torch.zeros(z.shape, device=device, requires_grad=True)
                        # z = torch.where(z > 0.5, one, zero)
                        if analyze_exact_match_sample:
                            print(z.squeeze().cpu().detach().numpy())
                            print('---')

                # end optim iterations

                # Select a final z
                z_final_soft = z.squeeze(dim=1).detach().cpu().numpy()
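                # Sample a hard 0/1 site selection from the relaxed z; it is repaired
                # below if it turns out empty or exceeds the budget z_epsilon.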

                z_final = np.random.binomial(1, z_final_soft)
                if analyze_exact_match_sample:
                    print('Final z -- ')
                    print(z_final_soft)
                    print(z_final)

                if sum(z_final) == 0 or sum(z_final) > z_epsilon:
                    if sum(z_final) == 0:
                        z_final_soft_idx = np.argsort(z_final_soft)[::-1][0]
                    elif sum(z_final) > z_epsilon:
                        z_final_soft_idx = np.argsort(
                            z_final_soft)[::-1][:z_epsilon]
                    z_final = np.zeros(z_final.shape)
                    z_final[z_final_soft_idx] = 1

                if analyze_exact_match_sample:
                    print('constraint: {}'.format(z_epsilon))
                    print('after constraint: {}'.format(z_final))

                for ix in range(z_final.shape[0]):
                    if z_final[ix] == 0:
                        # Find out the replace token corresponding to this site
                        remove_key = input_vocab.itos[site_map_lookup[ix]]
                        # Remove this token from best_replacements_sample
                        best_replacements_sample.pop(remove_key, None)

                if analyze_exact_match_sample:
                    print('Final best replacements', best_replacements_sample)

                if analyze_exact_match_sample:
                    if found_sample:
                        if len(batch_loss_list_per_iter) > 0:
                            out_str = 'ss{}_zlen-{}_n-{}_zstr-{}_opt-{}_lr-{}_uinit-{}_smooth-{}_udisc-{}'.format(
                                sname, zlen, plen, zstr, u_optim_technique,
                                u_learning_rate, u_init_pgd,
                                int(use_loss_smoothing), int(use_u_discrete))
                            print(out_str)
                            loss_plot(batch_loss_list_per_iter,
                                      os.path.join(outpth, out_str))

                best_replacements_dataset[fn_name] = best_replacements_sample

    print('Skipped and reverted to random attacks: {}/{} ({})'.format(
        rand_replacement_too_long, tot_samples,
        round(100 * rand_replacement_too_long / tot_samples, 2)))
    print('Nothing to attack: {}/{} ({})'.format(
        nothing_to_attack, tot_attacks,
        round(100 * nothing_to_attack / tot_attacks, 2)))
    print('----------------')
    print("# of samples attacked: {}".format(
        len(best_replacements_dataset.keys())))

    stats['reverted_to_random_attacks_pc'] = round(
        100 * rand_replacement_too_long / tot_samples, 2)
    stats['nothing_to_attack_pc'] = round(
        100 * nothing_to_attack / tot_attacks, 2)
    stats['n_samples_attacked'] = len(best_replacements_dataset.keys())

    if analyze_exact_match_sample:
        kzs = best_replacements_dataset.keys()
        for kz in kzs:
            print("{}::{}".format(kz, best_replacements_dataset[kz]))
        print('====')

    best_replacements_dataset, avg_replaced = get_all_replacements(
        best_replacements_dataset, field_name, orig_tok_map, idx_to_fname,
        True)

    if analyze_exact_match_sample:
        for kz in kzs:
            print("{}::{}".format(kz, best_replacements_dataset[kz]))

    print('\n# tokens optimized on average: {}'.format(avg_replaced))
    stats['n_tokens_optimized_avg'] = avg_replaced
    print("\n# of samples attacked post processing: {}\n=======".format(
        len(best_replacements_dataset.keys())))
    stats['n_samples_attacked_post_processing'] = len(
        best_replacements_dataset.keys())

    return best_replacements_dataset, stats
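The projection steps above (the `u_projection == 2` simplex projection and the z-budget projection) call a `bisection` root-finder that is defined elsewhere in the repository and not shown in this excerpt. Below is a minimal sketch under the interface assumed from the call sites, `bisection(f, a, b, n_iters)` returning a root of `f` in `[a, b]` or `None`; the actual helper in the source may differ.

import numpy as np


def bisection(f, a, b, n_iters):
    """Assumed helper: find mu in [a, b] with f(mu) ~= 0 by bisection.

    Returns None when f(a) and f(b) have the same sign, matching the
    `mu_opt is None` fallback used by the callers above.
    """
    fa, fb = f(a), f(b)
    if fa * fb > 0:
        return None
    lo, hi = a, b
    for _ in range(n_iters):
        mid = (lo + hi) / 2.0
        fmid = f(mid)
        if fmid == 0:
            return mid
        if fa * fmid < 0:
            hi = mid
        else:
            lo, fa = mid, fmid
    return (lo + hi) / 2.0


# Usage: project a vector onto the probability simplex, mirroring the
# `u_projection == 2` branch above.
v = np.array([0.2, 1.4, -0.3, 0.7])
fmu = lambda mu, a=v: np.sum(np.maximum(0, a - mu)) - 1
mu_opt = bisection(fmu, -1, 1, 30)
projected = np.maximum(0, v - mu_opt)  # sums to 1, only the largest entries survive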