def train(args, train_dataset, model, tokenizer, teacher=None):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            if teacher is not None:
                teacher.eval()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':       batch[0],
                      'attention_mask':  batch[1], 
                      'start_positions': batch[3], 
                      'end_positions':   batch[4]}
            if args.model_type != 'distilbert':
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[5],
                               'p_mask':       batch[6]})
            outputs = model(**inputs)
            loss, start_logits_stu, end_logits_stu = outputs

            # Distillation loss
            if teacher is not None:
                if 'token_type_ids' not in inputs:
                    inputs['token_type_ids'] = None if args.teacher_type == 'xlm' else batch[2]
                with torch.no_grad():
                    start_logits_tea, end_logits_tea = teacher(input_ids=inputs['input_ids'],
                                                               token_type_ids=inputs['token_type_ids'],
                                                               attention_mask=inputs['attention_mask'])
                assert start_logits_tea.size() == start_logits_stu.size()
                assert end_logits_tea.size() == end_logits_stu.size()
                
                loss_fct = nn.KLDivLoss(reduction='batchmean')
                loss_start = loss_fct(F.log_softmax(start_logits_stu/args.temperature, dim=-1),
                                      F.softmax(start_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
                loss_end = loss_fct(F.log_softmax(end_logits_stu/args.temperature, dim=-1),
                                    F.softmax(end_logits_tea/args.temperature, dim=-1)) * (args.temperature**2)
                loss_ce = (loss_start + loss_end)/2.

                loss = args.alpha_ce*loss_ce + args.alpha_squad*loss

            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
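The distillation term above feeds the student's temperature-scaled log-softmax and the teacher's temperature-scaled softmax into nn.KLDivLoss and blends it with the supervised loss. Below is a minimal standalone sketch of just that combination, with random tensors standing in for the real start/end logits and the QA loss:

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature, alpha_ce, alpha_squad = 2.0, 0.5, 0.5
logits_stu = torch.randn(8, 384, requires_grad=True)  # stand-in for student start/end logits
logits_tea = torch.randn(8, 384)                       # stand-in for teacher logits (no grad)
loss_squad = torch.tensor(1.3)                         # stand-in for the supervised QA loss

loss_fct = nn.KLDivLoss(reduction='batchmean')
loss_ce = loss_fct(F.log_softmax(logits_stu / temperature, dim=-1),
                   F.softmax(logits_tea / temperature, dim=-1)) * (temperature ** 2)
loss = alpha_ce * loss_ce + alpha_squad * loss_squad
loss.backward()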
Example #2
import torch
import torch.nn as nn
from torch import autograd

# `m` must produce per-class log-probabilities, e.g. a 3x3 conv over the 10x10
# input (yielding the 8x8 maps expected by `target`) followed by LogSoftmax:
m = nn.Sequential(nn.Conv2d(16, 4, (3, 3)), nn.LogSoftmax(dim=1))
loss = nn.NLLLoss2d()
# input is of size N x C x height x width
input = autograd.Variable(torch.randn(3, 16, 10, 10))
# each element in target has to have 0 <= value < C
target = autograd.Variable(torch.LongTensor(3, 8, 8).random_(0, 4))
output = loss(m(input), target)
output.backward()


# =============================================================================
#
# torch.nn.KLDivLoss
# loss(x, target) = (1/n) * sum_i( target_i * (log(target_i) - x_i) )
# =============================================================================

loss = nn.KLDivLoss(size_average=False)
batch_size = 5
log_probs1 = F.log_softmax(torch.randn(batch_size, 10), 1)
probs2 = F.softmax(torch.randn(batch_size, 10), 1)
loss(log_probs1, probs2) / batch_size
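As a quick sanity check of the formula in the comment above (a minimal sketch; reduction='sum' is the newer spelling of size_average=False), nn.KLDivLoss applied to log-probabilities x and probability targets matches the elementwise sum of target * (log(target) - x):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(5, 10), dim=1)  # x: log-probabilities
target = F.softmax(torch.randn(5, 10), dim=1)         # target: probabilities

kl = nn.KLDivLoss(reduction='sum')(log_probs, target)
manual = (target * (target.log() - log_probs)).sum()
print(torch.allclose(kl, manual))  # True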



# =============================================================================
# torch.nn.BCELoss
# loss(o, t) = -(1/n) * sum_i( t[i]*log(o[i]) + (1 - t[i])*log(1 - o[i]) )
# =============================================================================
m = nn.Sigmoid()
loss = nn.BCELoss()
input = autograd.Variable(torch.randn(3), requires_grad=True)
target = autograd.Variable(torch.FloatTensor(3).random_(2))
output = loss(m(input), target)
output.backward()
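The BCE formula in the comment can be checked the same way; a minimal sketch using plain tensors instead of the Variable API above:

import torch
import torch.nn as nn

torch.manual_seed(0)
o = torch.sigmoid(torch.randn(3))  # predicted probabilities in (0, 1)
t = torch.empty(3).random_(2)      # binary targets

bce = nn.BCELoss()(o, t)
manual = -(t * o.log() + (1 - t) * (1 - o).log()).mean()
print(torch.allclose(bce, manual))  # True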
Example #3
def train(gen_path, save_pth):
    model = Resnet18(n_classes=n_classes, pre_act=pre_act)
    model.train()
    model.cuda()
    criteria = nn.KLDivLoss(reduction='batchmean')
    generator = Resnet18(n_classes=10)
    state_dict = torch.load(gen_path)
    generator.load_state_dict(state_dict)
    generator.train()
    generator.cuda()

    batchsize = 256
    n_workers = 8
    dltrain = get_train_loader(
        batch_size=batchsize,
        num_workers=n_workers,
        dataset=ds_name,
        pin_memory=True
    )

    lr0 = 2e-1
    lr_eta = 1e-5
    momentum = 0.9
    wd = 5e-4
    n_epochs = 50
    n_warmup_epochs = 10
    warmup_start_lr = 1e-5
    warmup_method = 'linear'
    optim = torch.optim.SGD(
        model.parameters(),
        lr=lr0,
        momentum=momentum,
        weight_decay=wd
    )
    lr_scheduler = WarmupCosineAnnealingLR(
        optim,
        warmup_start_lr=warmup_start_lr,
        warmup_epochs=n_warmup_epochs,
        warmup=warmup_method,
        max_epochs=n_epochs,
        cos_eta=lr_eta,
    )

    for e in range(n_epochs):
        tic = time.time()
        model.train()
        lr_scheduler.step()
        loss_epoch = []
        for _, (ims, _) in enumerate(dltrain):
            ims = ims.cuda()
            # generate labels
            with torch.no_grad():
                lbs = generator(ims).clone()
                lbs = torch.softmax(lbs, dim=1)
            optim.zero_grad()
            if mixup:
                bs = ims.size(0)
                idx = torch.randperm(bs)
                lam = np.random.beta(mixup_alpha, mixup_alpha)
                ims_mix = lam * ims + (1.-lam) * ims[idx]
                logits = model(ims_mix)
                probs = F.log_softmax(logits, dim=1)
                loss1 = criteria(probs, lbs)
                loss2 = criteria(probs, lbs[idx])
                loss = lam * loss1 + (1.-lam) * loss2
            else:
                logits = model(ims)
                probs = F.log_softmax(logits, dim=1)
                loss = criteria(probs, lbs)
            loss.backward()
            loss_epoch.append(loss.item())
            optim.step()
        model.eval()
        acc = evaluate(model, verbose=False)
        toc = time.time()
        msg = 'epoch: {}, loss: {:.4f}, lr: {:.4f}, acc: {:.4f}, time: {:.2f}'.format(
            e,
            sum(loss_epoch)/len(loss_epoch),
            list(optim.param_groups)[0]['lr'],
            acc,
            toc - tic
        )
        print(msg)

    model.cpu()
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, save_pth)
    return model
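The mixup branch above mixes two images with weight lam and, symmetrically, mixes the two KL terms against the teacher distributions of the original and permuted batch. A minimal standalone sketch of that loss, with random stand-ins for the images and teacher outputs:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
criteria = nn.KLDivLoss(reduction='batchmean')
bs, n_classes, mixup_alpha = 4, 10, 0.5
lbs = F.softmax(torch.randn(bs, n_classes), dim=1)       # stand-in teacher probabilities
logits = torch.randn(bs, n_classes, requires_grad=True)  # stand-in student logits on the mixed batch
idx = torch.randperm(bs)
lam = np.random.beta(mixup_alpha, mixup_alpha)
probs = F.log_softmax(logits, dim=1)
loss = lam * criteria(probs, lbs) + (1. - lam) * criteria(probs, lbs[idx])
loss.backward()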
Example #4
def main():
    global args
    args = parse_args()
    args.input_dim, args.mem_dim = 300, 150
    args.hidden_dim, args.num_classes = 50, 5
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.sparse and args.wd != 0:
        print('Sparsity and weight decay are incompatible, pick one!')
        exit()
    print(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        token_files_a = [
            os.path.join(split, 'a.toks')
            for split in [train_dir, dev_dir, test_dir]
        ]
        token_files_b = [
            os.path.join(split, 'b.toks')
            for split in [train_dir, dev_dir, test_dir]
        ]
        token_files = token_files_a + token_files_b
        sick_vocab_file = os.path.join(args.data, 'sick.vocab')
        build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file,
                  data=[
                      Constants.PAD_WORD, Constants.UNK_WORD,
                      Constants.BOS_WORD, Constants.EOS_WORD
                  ])
    print('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    print('==> Size of train data   : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    print('==> Size of dev data     : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    print('==> Size of test data    : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(args.cuda, vocab.size(), args.input_dim,
                               args.mem_dim, args.hidden_dim, args.num_classes,
                               args.sparse)
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(
            os.path.join(args.glove, 'glove.840B.300d'))
        print('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(),
                           glove_emb.size(1)).normal_(-0.05, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([
                Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD,
                Constants.EOS_WORD
        ]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(
                    word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb)

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer)

    best = -float('inf')
    for epoch in range(args.epochs):
        train_loss = trainer.train(train_dataset)
        train_loss, train_pred = trainer.test(train_dataset)
        dev_loss, dev_pred = trainer.test(dev_dataset)
        test_loss, test_pred = trainer.test(test_dataset)

        train_pearson = metrics.pearson(train_pred, train_dataset.labels)
        train_mse = metrics.mse(train_pred, train_dataset.labels)
        print('==> Train    Loss: {}\tPearson: {}\tMSE: {}'.format(
            train_loss, train_pearson, train_mse))
        dev_pearson = metrics.pearson(dev_pred, dev_dataset.labels)
        dev_mse = metrics.mse(dev_pred, dev_dataset.labels)
        print('==> Dev      Loss: {}\tPearson: {}\tMSE: {}'.format(
            dev_loss, dev_pearson, dev_mse))
        test_pearson = metrics.pearson(test_pred, test_dataset.labels)
        test_mse = metrics.mse(test_pred, test_dataset.labels)
        print('==> Test     Loss: {}\tPearson: {}\tMSE: {}'.format(
            test_loss, test_pearson, test_mse))

        if best < test_pearson:
            best = test_pearson
            checkpoint = {
                'model': trainer.model.state_dict(),
                'optim': trainer.optimizer,
                'pearson': test_pearson,
                'mse': test_mse,
                'args': args,
                'epoch': epoch
            }
            print('==> New optimum found, checkpointing everything now...')
            torch.save(
                checkpoint,
                '%s.pt' % os.path.join(args.save, args.expname + '.pth'))
Example #5
    def __init__(self, params: dict, dataset: LmSeqsDataset,
                 token_probs: torch.tensor, student: nn.Module,
                 teacher: nn.Module):
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths,
                                           k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler,
                                          group_ids=groups,
                                          batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler,
                                   batch_size=params.batch_size,
                                   drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor(
                [params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(
                f"cuda:{params.local_rank}"
            ) if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(
                f"cuda:{params.local_rank}"
            ) if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps *
                params.n_epoch) + 1)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training",
                                      text_string=str(self.params),
                                      global_step=0)
            self.tensorboard.add_text(tag="config/student",
                                      text_string=str(self.student_config),
                                      global_step=0)
Example #6
def loss_fn_two_labels_min(outputs, labels, num_of_classes):
    """
        Compute the loss given outputs and labels.
        we will achieve max KL dist between out_bef_filt and lab_aft_filt by the following:
        we wish the output to be not equal to lab_aft_filt.
        given lab_aft_filt in [0-9] we can create a list of all other possible labels.
        we will get (num_of_classes-1) labels for each sample.
        we then calculate the KL loss between out_bef_filt and each other possible label and return
        the minimal KL distance to a certain label which is not lab_aft_filt

        Args:
            outputs: (Variable) dimension batch_size x 10 - output of the model
            labels: (Variable) dimension batch_size, where each element is a value in [0-9]
            num_of_classes: (int) value describing number of different classes (10)

        Returns:
            loss (Variable): loss for all images in the batch
    """

    # kl_criterion = nn.KLDivLoss(size_average=True, reduce=True)
    kl_criterion = nn.KLDivLoss()
    min_entropy_criterion = HLoss()

    label_before_filter = torch.index_select(
        labels, 1, torch.tensor([0], device=labels.device))
    label_after_filter = torch.index_select(
        labels, 1, torch.tensor([1], device=labels.device))

    all_labels_mat = np.arange(num_of_classes) * np.ones(
        (outputs.size()[0], 1), dtype=int)

    temp_bef = all_labels_mat - label_before_filter.cpu().numpy()
    temp_aft = all_labels_mat - label_after_filter.cpu().numpy()

    other_labels_before = torch.from_numpy(
        all_labels_mat[np.nonzero(temp_bef)].reshape(outputs.size()[0],
                                                     num_of_classes - 1))
    other_labels_after = torch.from_numpy(
        all_labels_mat[np.nonzero(temp_aft)].reshape(outputs.size()[0],
                                                     num_of_classes - 1))

    other_labels_before = other_labels_before.type(torch.LongTensor)

    many_hot_vector_before_filter = convert_int_to_one_hot_vector(
        other_labels_before, num_of_classes)
    one_hot_vector_after_filter = convert_int_to_one_hot_vector(
        label_after_filter, num_of_classes)
    one_hot_vector_before_filter = convert_int_to_one_hot_vector(
        label_before_filter, num_of_classes)  # unneeded

    out_before_filter = torch.index_select(
        outputs, 1, torch.tensor(list(range(10)), device=outputs.device))
    out_after_filter = torch.index_select(
        outputs, 1, torch.tensor(list(range(10, 20)), device=outputs.device))

    # min_ind = 0

    # for each option of 9 other labels (all besides the label after filter)
    # calculate the kl distance from the output and keep the minimal one for the loss function

    min_dist = kl_criterion(out_before_filter,
                            many_hot_vector_before_filter[:, 0])
    # ind_array = np.random.permutation(num_of_classes-1).tolist()
    for i in range(num_of_classes - 1):
        # other_one_hot_vector_before_filter = many_hot_vector_before_filter[:, ind_array[i]]
        other_one_hot_vector_before_filter = many_hot_vector_before_filter[:,
                                                                           i]
        dist_before = kl_criterion(out_before_filter,
                                   other_one_hot_vector_before_filter)
        if dist_before < min_dist:
            min_dist = dist_before
            # min_ind = i

    func = kl_criterion(out_after_filter, one_hot_vector_after_filter) + \
            min_dist + min_entropy_criterion(out_before_filter)

    return func
Example #7
    train.samples = dataset.get_k_fold_format_dataset(args.fold, 5)[0]

    trainloader = torch.utils.data.DataLoader(train,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8)
    dataloaders_dict = {'train': trainloader}
    device = torch.device(
        args.device)  # if torch.cuda.is_available() else "cpu")
    teacher_model.to(device).eval()
    # Send the model to GPU
    model = model.to(device)
    model.train()
    params_to_update = model.parameters()
    criterions = [
        nn.KLDivLoss(reduction='mean'),
        nn.CrossEntropyLoss(reduction='mean')
    ]
    # Train and evaluate
    model = train_model(model=model,
                        teacher_model=teacher_model,
                        dataloaders=dataloaders_dict,
                        criterions=criterions,
                        params_to_update=params_to_update,
                        num_epochs=args.num_epochs,
                        is_inception=is_inception,
                        best_acc=0,
                        early_stop_round=args.early_stop_round,
                        lr=args.learning_rate,
                        batch_update=1)
    torch.save(model.state_dict(),
Example #8
    def __init__(self, args):
        super(compAggWikiqa, self).__init__()

        self.mem_dim = args.mem_dim
        # self.att_dim = args.att_dim
        self.cov_dim = args.cov_dim
        self.learning_rate = args.learning_rate
        self.batch_size = args.batch_size
        self.emb_dim = args.wvecDim
        self.task = args.task
        self.numWords = args.numWords
        self.dropoutP = args.dropoutP
        self.grad = args.grad
        self.visualize = args.visualize
        self.emb_lr = args.emb_lr
        self.emb_partial = args.emb_partial
        self.comp_type = args.comp_type
        self.window_sizes = args.window_sizes
        self.window_large = args.window_large
        self.gpu = args.gpu

        self.best_score = 0

        self.emb_vecs = nn.Embedding(self.numWords, self.emb_dim)
        self.emb_vecs.weight.data = tr.loadVacab2Emb(self.task)

        # self.proj_module_master = self.new_proj_module()
        self.att_module_master = self.new_att_module()

        if self.comp_type == "mul":
            self.sim_sg_module = self.new_sim_mul_module()
        else:
            raise Exception("The word matching method is not provided")

        self.conv_module = self.new_conv_module()

        mem_dim = self.mem_dim

        class TempNet(nn.Module):
            def __init__(self, mem_dim):
                super(TempNet, self).__init__()
                self.layer1 = nn.Linear(mem_dim, 1)

            def forward(self, input):
                var1 = self.layer1(input)
                var1 = var1.view(-1)
                out = F.log_softmax(var1, dim=0)
                return out

        self.soft_module = TempNet(mem_dim)

        # self.join_module = lambda x: torch.cat(x, 0)

        self.optim_state = {"learningRate": self.learning_rate}
        self.criterion = nn.KLDivLoss()

        # That's a bug in pytorch, Container.py's parameters()
        def get_container_parameters(container):
            w = []
            gw = []
            for module in container.modules:
                mparam = list(module.parameters())
                if mparam:
                    w.extend(mparam[0])
                    gw.extend(mparam[1])
            if not w:
                return
            return w, gw

        # self.params, self.grad_params = get_container_parameters(modules)
        # print("Parameter size: %s" % self.params[0].size())
        # self.best_params = self.params.copy()  # TODO: revisit
        self.dropout_modules = [None] * 2
        self.proj_modules = [None] * 2
        # self.att_modules = [None] * 2

        self.proj_modules = self.new_proj_module()
        self.dropout_modules = nn.Dropout(self.dropoutP)
Example #9
def kld_loss(true, pred):
    # nn.KLDivLoss expects its first argument to be log-probabilities and its
    # second to be probabilities, hence the (pred, true) argument order.
    return nn.KLDivLoss()(pred, true)
Example #10
    def get_images(self, net_student=None, targets=None):
        print("get_images call")

        net_teacher = self.net_teacher
        use_fp16 = self.use_fp16
        save_every = self.save_every

        kl_loss = nn.KLDivLoss(reduction='batchmean').cuda()
        local_rank = torch.cuda.current_device()
        best_cost = 1e4
        criterion = self.criterion

        # setup target labels
        if targets is None:
            # only works for classification now, for other tasks need to provide target vector
            targets = torch.LongTensor([random.randint(0, 999) for _ in range(self.bs)]).to('cuda')
            if not self.random_label:
                # preselected classes, good for ResNet50v1.5
                targets = [1, 933, 946, 980, 25, 63, 92, 94, 107, 985, 151, 154, 207, 250, 270, 277, 283, 292, 294, 309,
                           311, 325, 340, 360, 386, 402, 403, 409, 530, 440, 468, 417, 590, 670, 817, 762, 920, 949,
                           963, 967, 574, 487]
                targets = targets[:10]

                targets = torch.LongTensor(targets * (int(self.bs / len(targets)))).to('cuda')

        img_original = self.image_resolution

        data_type = torch.half if use_fp16 else torch.float
        inputs = torch.randn((self.bs, 3, img_original, img_original), requires_grad=True, device='cuda',
                             dtype=data_type)
        pooling_function = nn.modules.pooling.AvgPool2d(kernel_size=2)

        if self.setting_id == 0:
            skipfirst = False
        else:
            skipfirst = True

        iteration = 0
        for lr_it, lower_res in enumerate([2, 1]):
            if lr_it == 0:
                iterations_per_layer = 2000
            else:
                iterations_per_layer = 1000 if not skipfirst else 2000
                if self.setting_id == 2:
                    iterations_per_layer = 20000

            if lr_it == 0 and skipfirst:
                continue

            lim_0, lim_1 = self.jitter // lower_res, self.jitter // lower_res

            if self.setting_id == 0:
                # multi resolution, 2k iterations with low resolution, 1k at normal, ResNet50v1.5 works the best, ResNet50 is ok
                optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.5, 0.9], eps=1e-8)
                do_clip = True
            elif self.setting_id == 1:
                # 2k normal resolution, for ResNet50v1.5; Resnet50 works as well
                optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.5, 0.9], eps=1e-8)
                do_clip = True
            elif self.setting_id == 2:
                # 20k normal resolution, the closest to the paper experiments for ResNet50
                optimizer = optim.Adam([inputs], lr=self.lr, betas=[0.9, 0.999], eps=1e-8)
                do_clip = False

            if use_fp16:
                static_loss_scale = 256
                static_loss_scale = "dynamic"
                # _, optimizer = amp.initialize([], optimizer, opt_level="O2", loss_scale=static_loss_scale)

            lr_scheduler = lr_cosine_policy(self.lr, 100, iterations_per_layer)

            for iteration_loc in range(iterations_per_layer):
                iteration += 1
                # learning rate scheduling
                lr_scheduler(optimizer, iteration_loc, iteration_loc)

                # perform downsampling if needed
                if lower_res != 1:
                    inputs_jit = pooling_function(inputs)
                else:
                    inputs_jit = inputs

                # apply random jitter offsets
                off1 = random.randint(-lim_0, lim_0)
                off2 = random.randint(-lim_1, lim_1)
                inputs_jit = torch.roll(inputs_jit, shifts=(off1, off2), dims=(2, 3))

                # Flipping
                flip = random.random() > 0.5
                if flip and self.do_flip:
                    inputs_jit = torch.flip(inputs_jit, dims=(3,))

                # forward pass
                optimizer.zero_grad()
                net_teacher.zero_grad()

                outputs = net_teacher(inputs_jit)
                outputs = self.network_output_function(outputs)

                # R_cross classification loss
                loss = criterion(outputs, targets)

                # R_prior losses
                loss_var_l1, loss_var_l2 = get_image_prior_losses(inputs_jit)

                # R_feature loss
                rescale = [self.first_bn_multiplier] + [1. for _ in range(len(self.loss_r_feature_layers) - 1)]
                loss_r_feature = sum(
                    [mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers)])

                # R_ADI
                loss_verifier_cig = torch.zeros(1)
                """
                if self.adi_scale != 0.0:
                    if self.detach_student:
                        outputs_student = net_student(inputs_jit).detach()
                    else:
                        outputs_student = net_student(inputs_jit)

                    T = 3.0
                    if 1:
                        T = 3.0
                        # Jensen Shanon divergence:
                        # another way to force KL between negative probabilities
                        P = nn.functional.softmax(outputs_student / T, dim=1)
                        Q = nn.functional.softmax(outputs / T, dim=1)
                        M = 0.5 * (P + Q)

                        P = torch.clamp(P, 0.01, 0.99)
                        Q = torch.clamp(Q, 0.01, 0.99)
                        M = torch.clamp(M, 0.01, 0.99)
                        eps = 0.0
                        loss_verifier_cig = 0.5 * kl_loss(torch.log(P + eps), M) + 0.5 * kl_loss(torch.log(Q + eps), M)
                        # JS criteria - 0 means full correlation, 1 - means completely different
                        loss_verifier_cig = 1.0 - torch.clamp(loss_verifier_cig, 0.0, 1.0)

                    if local_rank == 0:
                        if iteration % save_every == 0:
                            print('loss_verifier_cig', loss_verifier_cig.item())
                """
                # l2 loss on images
                loss_l2 = torch.norm(inputs_jit.view(self.bs, -1), dim=1).mean()

                # combining losses
                loss_aux = self.var_scale_l2 * loss_var_l2 + \
                           self.var_scale_l1 * loss_var_l1 + \
                           self.bn_reg_scale * loss_r_feature + \
                           self.l2_scale * loss_l2

                if self.adi_scale != 0.0:
                    loss_aux += self.adi_scale * loss_verifier_cig

                loss = self.main_loss_multiplier * loss + loss_aux

                if local_rank == 0:
                    if iteration % save_every == 0:
                        print("------------iteration {}----------".format(iteration))
                        print("total loss", loss.item())
                        print("loss_r_feature", loss_r_feature.item())
                        print("main criterion", criterion(outputs, targets).item())

                        if self.hook_for_display is not None:
                            self.hook_for_display(inputs, targets)

                # do image update
                if use_fp16:
                    pass
                    # optimizer.backward(loss)
                    # with amp.scale_loss(loss, optimizer) as scaled_loss:
                    #     scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()

                # clip color outliers
                if do_clip:
                    inputs.data = clip(inputs.data, use_fp16=use_fp16)

                if best_cost > loss.item() or iteration == 1:
                    best_inputs = inputs.data.clone()

                if iteration % save_every == 0 and (save_every > 0):
                    if local_rank == 0:
                        vutils.save_image(inputs,
                                          '{}/best_images/output_{:05d}_gpu_{}.png'.format(self.prefix,
                                                                                           iteration // save_every,
                                                                                           local_rank),
                                          normalize=True, scale_each=True, nrow=int(10))

        if self.store_best_images:
            best_inputs = denormalize(best_inputs)
            self.save_images(best_inputs, targets)

        # to reduce memory consumption by states of the optimizer we deallocate memory
        optimizer.state = collections.defaultdict(dict)
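The quoted-out R_ADI block above builds a Jensen-Shannon style term out of two KLDivLoss calls. As a rough standalone sketch (an assumption, not code from this repository): remembering that nn.KLDivLoss expects its first argument as log-probabilities and computes KL(target || those probabilities), a JS divergence between two batches of logits can be written as:

import torch
import torch.nn as nn
import torch.nn.functional as F

kl = nn.KLDivLoss(reduction='batchmean')

def js_divergence(logits_p, logits_q, eps=1e-8):
    # JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M the mixture (P + Q) / 2
    p = F.softmax(logits_p, dim=1)
    q = F.softmax(logits_q, dim=1)
    m = 0.5 * (p + q)
    return 0.5 * kl(torch.log(m + eps), p) + 0.5 * kl(torch.log(m + eps), q)

js = js_divergence(torch.randn(4, 10), torch.randn(4, 10))
print(js.item())  # >= 0; equals 0 when the two distributions coincide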
Example #11
    def __init__(self, reduce='mean'):
        super(SoftTargetCrossEntropy, self).__init__()
        self.criterion = nn.KLDivLoss(reduction=reduce)
        self.reduce = reduce
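The forward of a class like this is not shown here, but it would typically pair log-probabilities with soft targets, so the following usage sketch is an assumption. Note that KL against a soft target equals the soft-target cross-entropy minus the target entropy, which is constant with respect to the model, so the gradients coincide:

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.KLDivLoss(reduction='batchmean')
logits = torch.randn(4, 10, requires_grad=True)       # stand-in model outputs
soft_targets = F.softmax(torch.randn(4, 10), dim=1)   # e.g. label-smoothed or teacher targets
loss = criterion(F.log_softmax(logits, dim=1), soft_targets)
loss.backward()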
Example #12
            # print(newFolder, lastFolder)
            trim_frame_size = var.trim_frame_size
            utils = Helpers(test_folder)
            imu_training, imu_testing, training_target, testing_target = utils.load_datasets(
                args.reset_data, repeat=0)

            pipeline, model_checkpoint = models.get_model(
                args.model, test_folder)
            # pipeline.tensorboard_folder = args.tfolder
            optimizer = optim.Adam(pipeline.parameters(),
                                   lr=0.0001)  #, momentum=0.9)
            lambda1 = lambda epoch: 0.95**epoch
            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lr_lambda=lambda1)
            criterion = nn.KLDivLoss(reduction='batchmean')
            gt_act = nn.Softmax2d()
            best_test_loss = -np.inf
            if Path(pipeline.var.root + 'datasets/' + test_folder[5:] + '/' +
                    model_checkpoint).is_file():
                checkpoint = torch.load(pipeline.var.root + 'datasets/' +
                                        test_folder[5:] + '/' +
                                        model_checkpoint)
                pipeline.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                best_test_acc = checkpoint['best_test_loss']
                print('Model loaded')

            os.chdir(pipeline.var.root)
            print(torch.cuda.device_count())
            best_test_loss = 0.0
Example #13
def main():
    global args
    args = parse_args()
    args.input_dim, args.mem_dim = 300, 150
    args.hidden_dim, args.num_classes = 50, 5
    args.cuda = args.cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    train_dir = os.path.join(args.data,'train/')
    dev_dir = os.path.join(args.data,'dev/')
    test_dir = os.path.join(args.data,'test/')

    # write unique words from all token files
    token_files_a = [os.path.join(split,'a.toks') for split in [train_dir,dev_dir,test_dir]]
    token_files_b = [os.path.join(split,'b.toks') for split in [train_dir,dev_dir,test_dir]]
    token_files = token_files_a+token_files_b
    sick_vocab_file = os.path.join(args.data,'sick.vocab')
    build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file, data=[Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD])
    print('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data,'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    print('==> Size of train data   : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data,'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    print('==> Size of dev data     : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data,'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    print('==> Size of test data    : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(
                args.cuda, vocab.size(),
                args.input_dim, args.mem_dim,
                args.hidden_dim, args.num_classes
            )
    criterion = nn.KLDivLoss()
    if args.cuda:
        model.cuda(), criterion.cuda()
    if args.optim=='adam':
        optimizer   = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    elif args.optim=='adagrad':
        optimizer   = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(os.path.join(args.glove,'glove.840B.300d'))
        print('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(),glove_emb.size(1)).normal_(-0.05,0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.childsumtreelstm.emb.state_dict()['weight'].copy_(emb)

    # create trainer object for training and testing
    trainer     = Trainer(args, model, criterion, optimizer)

    for epoch in range(args.epochs):
        train_loss             = trainer.train(train_dataset)
        train_loss, train_pred = trainer.test(train_dataset)
        dev_loss, dev_pred     = trainer.test(dev_dataset)
        test_loss, test_pred   = trainer.test(test_dataset)

        print('==> Train loss   : %f \t' % train_loss, end="")
        print('Train Pearson    : %f \t' % metrics.pearson(train_pred,train_dataset.labels), end="")
        print('Train MSE        : %f \t' % metrics.mse(train_pred,train_dataset.labels), end="\n")
        print('==> Dev loss     : %f \t' % dev_loss, end="")
        print('Dev Pearson      : %f \t' % metrics.pearson(dev_pred,dev_dataset.labels), end="")
        print('Dev MSE          : %f \t' % metrics.mse(dev_pred,dev_dataset.labels), end="\n")
        print('==> Test loss    : %f \t' % test_loss, end="")
        print('Test Pearson     : %f \t' % metrics.pearson(test_pred,test_dataset.labels), end="")
        print('Test MSE         : %f \t' % metrics.mse(test_pred,test_dataset.labels), end="\n")
Example #14
def main():
    global args
    args = parse_args()
    # global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
    # file logger
    fh = logging.FileHandler(os.path.join(args.save, args.expname)+'.log', mode='w')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # console logger
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    # argument validation
    args.cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if args.sparse and args.wd != 0:
        logger.error('Sparsity and weight decay are incompatible, pick one!')
        exit()
    logger.debug(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    sick_vocab_file = os.path.join(args.data, 'sick.vocab')
    if not os.path.isfile(sick_vocab_file):
        token_files_b = [os.path.join(split, 'b.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files_a = [os.path.join(split, 'a.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files = token_files_a + token_files_b
        sick_vocab_file = os.path.join(args.data, 'sick.vocab')
        utils.build_vocab(token_files, sick_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=sick_vocab_file,
                  data=[Constants.PAD_WORD, Constants.UNK_WORD,
                        Constants.BOS_WORD, Constants.EOS_WORD])
    logger.debug('==> SICK vocabulary size : %d ' % vocab.size())

    # load SICK dataset splits
    train_file = os.path.join(args.data, 'sick_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = SICKDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    logger.debug('==> Size of train data   : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'sick_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    logger.debug('==> Size of dev data     : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'sick_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = SICKDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    logger.debug('==> Size of test data    : %d ' % len(test_dataset))

    # initialize model, criterion/loss_function, optimizer
    model = Hybrid(
        vocab.size(),
        args.input_dim,
        args.mem_dim,
        args.hidden_dim,
        args.num_classes,
        args.sparse,
        args.freeze_embed)
    print("args.input_dim",args.input_dim)
    print("args.mem_dim",args.mem_dim)
    print("args.hidden_dim",args.hidden_dim)
    print("args.num_classes",args.num_classes)
    print("args.sparse",args.sparse)
    print("args.freeze_embed", args.freeze_embed)
    criterion = nn.KLDivLoss()

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'sick_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = utils.load_word_vectors(
            os.path.join(args.glove, 'glove.840B.300d'))
        logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        print("glove_emb",glove.shape)
        emb = torch.zeros(vocab.size(), glove_emb.size(1), dtype=torch.float, device=device)
        emb.normal_(0, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD,
                                    Constants.BOS_WORD, Constants.EOS_WORD]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    model.emb.weight.data.copy_(emb)

    model.to(device), criterion.to(device)
    if args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                         model.parameters()), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()), lr=args.lr, weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer, device)

    best = -float('inf')
    for epoch in range(args.epochs):
        train_loss = trainer.train(train_dataset)
        train_loss, train_pred = trainer.test(train_dataset)
        dev_loss, dev_pred = trainer.test(dev_dataset)
        test_loss, test_pred = trainer.test(test_dataset)

        train_pearson = metrics.pearson(train_pred, train_dataset.labels)
        train_mse = metrics.mse(train_pred, train_dataset.labels)
        logger.info('==> Epoch {}, Train \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, train_loss, train_pearson, train_mse))
        dev_pearson = metrics.pearson(dev_pred, dev_dataset.labels)
        dev_mse = metrics.mse(dev_pred, dev_dataset.labels)
        logger.info('==> Epoch {}, Dev \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, dev_loss, dev_pearson, dev_mse))
        test_pearson = metrics.pearson(test_pred, test_dataset.labels)
        test_mse = metrics.mse(test_pred, test_dataset.labels)
        logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format(
            epoch, test_loss, test_pearson, test_mse))

        if best < test_pearson:
            best = test_pearson
            checkpoint = {
                'model': trainer.model.state_dict(),
                'optim': trainer.optimizer,
                'pearson': test_pearson, 'mse': test_mse,
                'args': args, 'epoch': epoch
            }
            logger.debug('==> New optimum found, checkpointing everything now...')
            torch.save(checkpoint, '%s.pt' % os.path.join(args.save, args.expname))
Example #15
def trades_loss(model,
                x_natural,
                y,
                optimizer,
                noise,
                mask,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=6.0,
                distance='l_inf'):
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)

    # generate adversarial example
    x_natural = x_natural.mul(mask)
    x_adv = x_natural.detach() + 0.001 * torch.randn(
        x_natural.shape).cuda().detach()
    x_adv = x_adv.mul(mask)
    noise = noise.mul(mask).detach()
    if distance == 'l_inf':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(
                    F.log_softmax(model(x_adv + noise), dim=1),
                    F.softmax(model(x_natural + noise), dim=1))
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                              x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
            x_adv = x_adv.mul(mask)
    elif distance == 'l_2':
        delta = 0.001 * torch.randn(x_natural.shape).cuda().mul(mask).detach()
        delta = Variable(delta.data, requires_grad=True)

        # Setup optimizers
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            # optimize
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss = (-1) * criterion_kl(
                    F.log_softmax(model(adv + noise), dim=1),
                    F.softmax(model(x_natural + noise), dim=1))
            loss.backward()
            # renorming gradient
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # avoid nan or inf if gradient is 0
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(
                    delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # projection
            delta.data.add_(x_natural)
            delta.data.clamp_(0, 1).sub_(x_natural)
            delta.data.mul_(mask)  # in-place; a bare mul() would discard the result
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x_natural + delta, requires_grad=False)
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    logits = model(x_natural + noise)
    loss_natural = F.cross_entropy(logits, y)
    # loss_natural = nn.MultiMarginLoss()(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(
        F.log_softmax(model(x_adv + noise), dim=1),
        F.softmax(model(x_natural + noise), dim=1))
    loss = loss_natural + beta * loss_robust

    return loss
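
# A minimal usage sketch (added, not part of the original source) for the masked
# trades_loss above. It assumes a classifier `model`, an optimizer, and a CUDA
# batch (x, y) with pixel values in [0, 1]; `noise` and `mask` are placeholder
# tensors here, since their construction lives elsewhere in the original project.
def trades_train_step(model, optimizer, x, y):
    noise = torch.zeros_like(x)  # assumed: additive input noise used by the model
    mask = torch.ones_like(x)    # assumed: binary mask applied to the input
    loss = trades_loss(model, x, y, optimizer, noise, mask,
                       step_size=0.003, epsilon=0.031,
                       perturb_steps=10, beta=6.0, distance='l_inf')
    loss.backward()  # trades_loss already called optimizer.zero_grad()
    optimizer.step()
    return loss.item()
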
def train(args, train_dataset, model, tokenizer, teacher=None):
    """Train the model"""
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = (len(train_dataloader) // args.gradient_accumulation_steps
                   * args.num_train_epochs)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            if teacher is not None:
                teacher.eval()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }
            if args.model_type != "distilbert":
                inputs[
                    "token_type_ids"] = None if args.model_type == "xlm" else batch[
                        2]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
            outputs = model(**inputs)
            loss, start_logits_stu, end_logits_stu = outputs

            # Distillation loss
            if teacher is not None:
                if "token_type_ids" not in inputs:
                    inputs[
                        "token_type_ids"] = None if args.teacher_type == "xlm" else batch[
                            2]
                with torch.no_grad():
                    start_logits_tea, end_logits_tea = teacher(
                        input_ids=inputs["input_ids"],
                        token_type_ids=inputs["token_type_ids"],
                        attention_mask=inputs["attention_mask"],
                    )
                assert start_logits_tea.size() == start_logits_stu.size()
                assert end_logits_tea.size() == end_logits_stu.size()

                loss_fct = nn.KLDivLoss(reduction="batchmean")
                loss_start = (loss_fct(
                    nn.functional.log_softmax(
                        start_logits_stu / args.temperature, dim=-1),
                    nn.functional.softmax(start_logits_tea / args.temperature,
                                          dim=-1),
                ) * (args.temperature**2))
                loss_end = (loss_fct(
                    nn.functional.log_softmax(
                        end_logits_stu / args.temperature, dim=-1),
                    nn.functional.softmax(end_logits_tea / args.temperature,
                                          dim=-1),
                ) * (args.temperature**2))
                loss_ce = (loss_start + loss_end) / 2.0

                loss = args.alpha_ce * loss_ce + args.alpha_squad * loss

            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel (not distributed) training
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                             args.max_grad_norm)
                else:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #17
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

bce = nn.BCELoss(reduction='none')
mse = nn.MSELoss()
kl_loss = nn.KLDivLoss(reduction='none')

def bce_loss(score, label, mode='sum'):
    raw_loss = bce(score, label)

    if mode == 'sum':
        return torch.sum(raw_loss)
    elif mode == 'average':
        return torch.mean(raw_loss)
    elif mode == 'raw':
        return torch.mean(raw_loss, dim=1)

    return raw_loss
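
# A short demo (added, not in the original source) of the three reduction modes
# of bce_loss above, using a toy batch of sigmoid scores and binary labels.
def _bce_loss_demo():
    score = torch.sigmoid(torch.randn(4, 3))           # probabilities in (0, 1)
    label = torch.randint(0, 2, (4, 3)).float()        # binary targets
    summed = bce_loss(score, label, mode='sum')        # scalar, summed over all elements
    averaged = bce_loss(score, label, mode='average')  # scalar, mean over all elements
    per_sample = bce_loss(score, label, mode='raw')    # shape (4,), per-sample mean over dim=1
    return summed, averaged, per_sample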


def mse_loss(src, target):
    return mse(src, target)

def attention_loss(src, target):
    """
    
    Args:
        src: The attention score
        target: The output of Goal channel, goal score
Example #18
    def Semi_SRL_Loss(self, hidden_forward, hidden_backward,
                      Predicate_idx_batch, unlabeled_sentence, TagProbs_use,
                      unlabeled_lengths):
        TagProbs_use_softmax = F.softmax(TagProbs_use, dim=2).detach()
        sample_nums = unlabeled_lengths.sum()
        unlabeled_loss_function = nn.KLDivLoss(reduce=False)
        ## Dependency Extractor FF
        hidden_states_word = self.dropout_1_FF(
            F.relu(self.Non_Predicate_Proj_FF(hidden_forward)))
        predicate_embeds = hidden_forward[np.arange(0,
                                                    hidden_forward.size()[0]),
                                          Predicate_idx_batch]
        hidden_states_predicate = self.dropout_2_FF(
            F.relu(self.Predicate_Proj_FF(predicate_embeds)))

        bias_one = torch.ones(
            (self.batch_size, len(unlabeled_sentence[0]), 1)).to(device)
        hidden_states_word = torch.cat(
            (hidden_states_word, Variable(bias_one)), 2)

        bias_one = torch.ones((self.batch_size, 1)).to(device)
        hidden_states_predicate = torch.cat(
            (hidden_states_predicate, Variable(bias_one)), 1)

        left_part = torch.mm(
            hidden_states_word.view(
                self.batch_size * len(unlabeled_sentence[0]), -1), self.W_R_FF)
        left_part = left_part.view(
            self.batch_size,
            len(unlabeled_sentence[0]) * self.tagset_size, -1)
        hidden_states_predicate = hidden_states_predicate.view(
            self.batch_size, -1, 1)
        tag_space = torch.bmm(left_part, hidden_states_predicate).view(
            self.batch_size, len(unlabeled_sentence[0]), -1)
        DEPprobs_student = F.log_softmax(tag_space, dim=2)
        DEP_FF_loss = unlabeled_loss_function(DEPprobs_student,
                                              TagProbs_use_softmax)

        ## Dependency Extractor BB
        hidden_states_word = self.dropout_1_BB(
            F.relu(self.Non_Predicate_Proj_BB(hidden_backward)))
        predicate_embeds = hidden_backward[np.arange(0,
                                                     hidden_forward.size()[0]),
                                           Predicate_idx_batch]
        hidden_states_predicate = self.dropout_2_BB(
            F.relu(self.Predicate_Proj_BB(predicate_embeds)))

        bias_one = torch.ones(
            (self.batch_size, len(unlabeled_sentence[0]), 1)).to(device)
        hidden_states_word = torch.cat(
            (hidden_states_word, Variable(bias_one)), 2)

        bias_one = torch.ones((self.batch_size, 1)).to(device)
        hidden_states_predicate = torch.cat(
            (hidden_states_predicate, Variable(bias_one)), 1)

        left_part = torch.mm(
            hidden_states_word.view(
                self.batch_size * len(unlabeled_sentence[0]), -1), self.W_R_BB)
        left_part = left_part.view(
            self.batch_size,
            len(unlabeled_sentence[0]) * self.tagset_size, -1)
        hidden_states_predicate = hidden_states_predicate.view(
            self.batch_size, -1, 1)
        tag_space = torch.bmm(left_part, hidden_states_predicate).view(
            self.batch_size, len(unlabeled_sentence[0]), -1)
        DEPprobs_student = F.log_softmax(tag_space, dim=2)
        DEP_BB_loss = unlabeled_loss_function(DEPprobs_student,
                                              TagProbs_use_softmax)

        ## Dependency Extractor FB
        hidden_states_word = self.dropout_1_FB(
            F.relu(self.Non_Predicate_Proj_FB(hidden_forward)))
        predicate_embeds = hidden_backward[np.arange(0,
                                                     hidden_forward.size()[0]),
                                           Predicate_idx_batch]
        hidden_states_predicate = self.dropout_2_FB(
            F.relu(self.Predicate_Proj_FB(predicate_embeds)))

        bias_one = torch.ones(
            (self.batch_size, len(unlabeled_sentence[0]), 1)).to(device)
        hidden_states_word = torch.cat(
            (hidden_states_word, Variable(bias_one)), 2)
        bias_one = torch.ones((self.batch_size, 1)).to(device)
        hidden_states_predicate = torch.cat(
            (hidden_states_predicate, Variable(bias_one)), 1)

        left_part = torch.mm(
            hidden_states_word.view(
                self.batch_size * len(unlabeled_sentence[0]), -1), self.W_R_FB)
        left_part = left_part.view(
            self.batch_size,
            len(unlabeled_sentence[0]) * self.tagset_size, -1)
        hidden_states_predicate = hidden_states_predicate.view(
            self.batch_size, -1, 1)
        tag_space = torch.bmm(left_part, hidden_states_predicate).view(
            self.batch_size, len(unlabeled_sentence[0]), -1)
        DEPprobs_student = F.log_softmax(tag_space, dim=2)
        DEP_FB_loss = unlabeled_loss_function(DEPprobs_student,
                                              TagProbs_use_softmax)

        ## Dependency Extractor BF
        hidden_states_word = self.dropout_1_BF(
            F.relu(self.Non_Predicate_Proj_BF(hidden_backward)))
        predicate_embeds = hidden_forward[np.arange(0,
                                                    hidden_forward.size()[0]),
                                          Predicate_idx_batch]
        hidden_states_predicate = self.dropout_2_BF(
            F.relu(self.Predicate_Proj_BF(predicate_embeds)))

        bias_one = torch.ones(
            (self.batch_size, len(unlabeled_sentence[0]), 1)).to(device)
        hidden_states_word = torch.cat(
            (hidden_states_word, Variable(bias_one)), 2)
        bias_one = torch.ones((self.batch_size, 1)).to(device)
        hidden_states_predicate = torch.cat(
            (hidden_states_predicate, Variable(bias_one)), 1)

        left_part = torch.mm(
            hidden_states_word.view(
                self.batch_size * len(unlabeled_sentence[0]), -1), self.W_R_BF)
        left_part = left_part.view(
            self.batch_size,
            len(unlabeled_sentence[0]) * self.tagset_size, -1)
        hidden_states_predicate = hidden_states_predicate.view(
            self.batch_size, -1, 1)
        tag_space = torch.bmm(left_part, hidden_states_predicate).view(
            self.batch_size, len(unlabeled_sentence[0]), -1)
        DEPprobs_student = F.log_softmax(tag_space, dim=2)
        DEP_BF_loss = unlabeled_loss_function(DEPprobs_student,
                                              TagProbs_use_softmax)

        DEP_Semi_loss = self.mask_loss(
            DEP_FF_loss + DEP_BB_loss + DEP_BF_loss + DEP_FB_loss,
            unlabeled_lengths)
        DEP_Semi_loss = torch.sum(DEP_Semi_loss)
        return DEP_Semi_loss / sample_nums
Example #19
'''
  The idea of UDA is from: https://arxiv.org/abs/1904.12848
  The code is adapted from: https://github.com/google-research/uda/tree/960684e363251772a5938451d4d2bc0f1da9e24b

  Note: the code was adapted and simplified based on our understanding of the paper.
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

KL_loss = nn.KLDivLoss(reduction="none")
CrossEntropyLoss = nn.CrossEntropyLoss(reduction="none")


def get_threshold(current_step, total_step, tsa_type, num_classes):
    if tsa_type == "log":
        alpha = 1 - math.exp(-current_step / total_step * 5.)
    elif tsa_type == "linear":
        alpha = current_step / total_step
    else:
        alpha = math.exp((current_step / total_step - 1) * 5.)

    return alpha * (1.0 - 1.0 / num_classes) + 1.0 / num_classes
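
# A small worked example (added, not in the original source) of the TSA threshold
# schedule above: with num_classes=10 the threshold starts near 1/num_classes = 0.1
# and grows toward 1.0 over training, at a rate controlled by tsa_type.
def _tsa_schedule_demo(total_step=1000, num_classes=10):
    for tsa_type in ("log", "linear", "exp"):
        thresholds = [get_threshold(step, total_step, tsa_type, num_classes)
                      for step in (0, total_step // 2, total_step)]
        print(tsa_type, ["%.3f" % t for t in thresholds])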


def torch_device_one():
    return torch.tensor(1.).to(_get_device())

    def __init__(self, reduction='mean'):
        super().__init__(reduction)
        self.kl_div_loss = nn.KLDivLoss(reduction=reduction)
Example #21
def trades_loss(model,
                x_natural,
                y,
                optimizer,
                device,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=1.0,
                distance='l_inf'):
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)
    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach()
    if distance == 'l_inf':
        # logits_natural = model(x_natural).detach()

        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
                # loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                #                        F.softmax(logits_natural, dim=1))

            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)

    elif distance == 'l_2':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            for idx_batch in range(batch_size):
                grad_idx = grad[idx_batch]
                grad_idx_norm = l2_norm(grad_idx)
                grad_idx /= (grad_idx_norm + 1e-8)
                x_adv[idx_batch] = x_adv[idx_batch].detach() + step_size * grad_idx
                eta_x_adv = x_adv[idx_batch] - x_natural[idx_batch]
                norm_eta = l2_norm(eta_x_adv)
                if norm_eta > epsilon:
                    eta_x_adv = eta_x_adv * epsilon / l2_norm(eta_x_adv)
                x_adv[idx_batch] = x_natural[idx_batch] + eta_x_adv
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    logits = model(x_natural)
    adv_logits = model(x_adv)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(adv_logits, dim=1),
                                                    F.softmax(logits, dim=1))
    loss = loss_natural + beta * loss_robust

    cleanacc = torch_accuracy(logits, y, (1,))[0].item()
    tradesacc = torch_accuracy(adv_logits, y, (1,))[0].item()
    return loss, loss_natural.item(), loss_robust.item(), cleanacc, tradesacc
Example #22
def train():
    criterion = nn.KLDivLoss(size_average=False)
    train_loss = np.zeros(opt.MAX_ITERATIONS + 1)
    results = []
    for iter_idx, (data, word_length, feature, answer, glove,
                   epoch) in enumerate(train_Loader):
        model.train()
        data = np.squeeze(data, axis=0)
        word_length = np.squeeze(word_length, axis=0)
        feature = np.squeeze(feature, axis=0)
        answer = np.squeeze(answer, axis=0)
        glove = np.squeeze(glove, axis=0)
        epoch = epoch.numpy()

        data = Variable(data).cuda().long()
        word_length = word_length.cuda()
        img_feature = Variable(feature).cuda().float()
        label = Variable(answer).cuda().float()
        glove = Variable(glove).cuda().float()

        optimizer.zero_grad()
        pred = model(data, word_length, img_feature, glove, 'train')
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        train_loss[iter_idx] = loss.item()  # .item() replaces the deprecated loss.data[0]
        if iter_idx % opt.DECAY_STEPS == 0 and iter_idx != 0:
            adjust_learning_rate(optimizer, opt.DECAY_RATE)
        if iter_idx % opt.PRINT_INTERVAL == 0 and iter_idx != 0:
            now = str(datetime.datetime.now())
            c_mean_loss = (train_loss[iter_idx - opt.PRINT_INTERVAL:iter_idx]
                           .mean() / opt.BATCH_SIZE)
            writer.add_scalar('mfh_coatt_glove/train_loss', c_mean_loss,
                              iter_idx)
            writer.add_scalar('mfh_coatt_glove/lr',
                              optimizer.param_groups[0]['lr'], iter_idx)
            print('{}\tTrain Epoch: {}\tIter: {}\tLoss: {:.4f}'.format(
                now, epoch, iter_idx, c_mean_loss))
        if iter_idx % opt.CHECKPOINT_INTERVAL == 0 and iter_idx != 0:
            if not os.path.exists('./data'):
                os.makedirs('./data')
            save_path = './data/mfh_coatt_glove_iter_' + str(iter_idx) + '.pth'
            torch.save(model.state_dict(), save_path)
        if iter_idx % opt.VAL_INTERVAL == 0 and iter_idx != 0:
            test_loss, acc_overall, acc_per_ques, acc_per_ans = exec_validation(
                model, opt, mode='val', folder=folder, it=iter_idx)
            writer.add_scalar('mfh_coatt_glove/val_loss', test_loss, iter_idx)
            writer.add_scalar('mfh_coatt_glove/accuracy', acc_overall,
                              iter_idx)
            print('Test loss:', test_loss)
            print('Accuracy:', acc_overall)
            print('Test per ans', acc_per_ans)
            results.append([
                iter_idx, c_mean_loss, test_loss, acc_overall, acc_per_ques,
                acc_per_ans
            ])
            best_result_idx = np.array([x[3] for x in results]).argmax()
            print('Best accuracy of', results[best_result_idx][3],
                  'was at iteration', results[best_result_idx][0])
            drawgraph(results,
                      folder,
                      opt.MFB_FACTOR_NUM,
                      opt.MFB_OUT_DIM,
                      prefix='mfh_coatt_glove')
        if iter_idx % opt.TESTDEV_INTERVAL == 0 and iter_idx != 0:
            exec_validation(model,
                            opt,
                            mode='test-dev',
                            folder=folder,
                            it=iter_idx)
Example #23
def main():
    global args
    args = parse_args()
    # global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
    # file logger
    fh = logging.FileHandler(os.path.join(args.save, args.expname) + '.log', mode='w')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    # console logger
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    # argument validation
    args.cuda = args.cuda and torch.cuda.is_available()
    if args.sparse and args.wd != 0:
        logger.error('Sparsity and weight decay are incompatible, pick one!')
        exit()
    logger.debug(args)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    train_dir = os.path.join(args.data, 'train/')
    dev_dir = os.path.join(args.data, 'dev/')
    test_dir = os.path.join(args.data, 'test/')

    # write unique words from all token files
    dataset_vocab_file = os.path.join(args.data, 'dataset.vocab')
    if not os.path.isfile(dataset_vocab_file):
        token_files_a = [os.path.join(split, 'a.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files_b = [os.path.join(split, 'b.toks') for split in [train_dir, dev_dir, test_dir]]
        token_files = token_files_a + token_files_b
        dataset_vocab_file = os.path.join(args.data, 'dataset.vocab')
        build_vocab(token_files, dataset_vocab_file)

    # get vocab object from vocab file previously written
    vocab = Vocab(filename=dataset_vocab_file,
                  data=[Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD])
    logger.debug('==> Dataset vocabulary size : %d ' % vocab.size())

    # load dataset splits
    train_file = os.path.join(args.data, 'dataset_train.pth')
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = QGDataset(train_dir, vocab, args.num_classes)
        torch.save(train_dataset, train_file)
    logger.debug('==> Size of train data   : %d ' % len(train_dataset))
    dev_file = os.path.join(args.data, 'dataset_dev.pth')
    if os.path.isfile(dev_file):
        dev_dataset = torch.load(dev_file)
    else:
        dev_dataset = QGDataset(dev_dir, vocab, args.num_classes)
        torch.save(dev_dataset, dev_file)
    logger.debug('==> Size of dev data     : %d ' % len(dev_dataset))
    test_file = os.path.join(args.data, 'dataset_test.pth')
    if os.path.isfile(test_file):
        test_dataset = torch.load(test_file)
    else:
        test_dataset = QGDataset(test_dir, vocab, args.num_classes)
        torch.save(test_dataset, test_file)
    logger.debug('==> Size of test data    : %d ' % len(test_dataset))

    if args.sim == "cos":
        similarity = CosSimilarity(1)
    else:
        similarity = DASimilarity(args.mem_dim, args.hidden_dim, args.num_classes)

    # initialize model, criterion/loss_function, optimizer
    model = SimilarityTreeLSTM(
        vocab.size(),
        args.input_dim,
        args.mem_dim,
        similarity,
        args.sparse)
    criterion = nn.KLDivLoss()  # nn.HingeEmbeddingLoss()

    if args.cuda:
        model.cuda(), criterion.cuda()
    else:
        torch.set_num_threads(4)
    logger.info("number of available cores: {}".format(torch.get_num_threads()))
    if args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wd)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd)
    metrics = Metrics(args.num_classes)

    # for words common to dataset vocab and GLOVE, use GLOVE vectors
    # for other words in dataset vocab, use random normal vectors
    emb_file = os.path.join(args.data, 'dataset_embed.pth')
    if os.path.isfile(emb_file):
        emb = torch.load(emb_file)
    else:
        # load glove embeddings and vocab
        glove_vocab, glove_emb = load_word_vectors(os.path.join(args.glove, 'glove.840B.300d'))
        logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size())
        emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05)
        # zero out the embeddings for padding and other special words if they are absent in vocab
        for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]):
            emb[idx].zero_()
        for word in vocab.labelToIdx.keys():
            if glove_vocab.getIndex(word):
                emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)]
        torch.save(emb, emb_file)
    # plug these into embedding matrix inside model
    if args.cuda:
        emb = emb.cuda()
    model.emb.weight.data.copy_(emb)

    checkpoint_filename = '%s.pt' % os.path.join(args.save, args.expname)
    if args.mode == "test":
        checkpoint = torch.load(checkpoint_filename)
        model.load_state_dict(checkpoint['model'])
        args.epochs = 1

    # create trainer object for training and testing
    trainer = Trainer(args, model, criterion, optimizer)

    for epoch in range(args.epochs):
        if args.mode == "train":
            train_loss = trainer.train(train_dataset)
            train_loss, train_pred = trainer.test(train_dataset)
            logger.info(
                '==> Epoch {}, Train \tLoss: {} {}'.format(epoch, train_loss,
                                                           metrics.all(train_pred, train_dataset.labels)))
            checkpoint = {'model': trainer.model.state_dict(), 'optim': trainer.optimizer,
                          'args': args, 'epoch': epoch}
            torch.save(checkpoint, checkpoint_filename)

        dev_loss, dev_pred = trainer.test(dev_dataset)
        test_loss, test_pred = trainer.test(test_dataset)
        logger.info(
            '==> Epoch {}, Dev \tLoss: {} {}'.format(epoch, dev_loss, metrics.all(dev_pred, dev_dataset.labels)))
        logger.info(
            '==> Epoch {}, Test \tLoss: {} {}'.format(epoch, test_loss, metrics.all(test_pred, test_dataset.labels)))
def KLDivLoss(Stu_output, Tea_output, temperature = 1):
    T = temperature
    KD_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(Stu_output/T, dim=1), F.softmax(Tea_output/T, dim=1))
    KD_loss = KD_loss * T * T
    return KD_loss
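
# A minimal distillation sketch (added, not part of the original source) using the
# KLDivLoss helper above: the temperature-scaled KD term is mixed with the usual
# hard-label cross-entropy. The 0.5 mixing weight and temperature of 4.0 are
# illustrative choices, not values taken from the source.
def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.5):
    kd = KLDivLoss(student_logits, teacher_logits.detach(), temperature=temperature)
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1.0 - alpha) * ce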
Example #25
def trades_loss(
        model,
        x_natural,
        y,
        device,
        optimizer,
        step_size,
        epsilon,
        perturb_steps,
        beta,
        clip_min,
        clip_max,
        distance="l_inf",
        natural_criterion=nn.CrossEntropyLoss(),
):
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)
    # generate adversarial example
    x_adv = (x_natural.detach() +
             0.001 * torch.randn(x_natural.shape).to(device).detach())
    if distance == "l_inf":
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(
                    F.log_softmax(model(x_adv), dim=1),
                    F.softmax(model(x_natural), dim=1),
                )
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                              x_natural + epsilon)
            x_adv = torch.clamp(x_adv, clip_min, clip_max)
    elif distance == "l_2":
        delta = 0.001 * torch.randn(x_natural.shape).to(device).detach()
        delta = Variable(delta.data, requires_grad=True)

        # Setup optimizers
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            # optimize
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1),
                                           F.softmax(model(x_natural), dim=1))
            loss.backward()
            # renorming gradient
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # avoid nan or inf if gradient is 0
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(
                    delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # projection
            delta.data.add_(x_natural)
            delta.data.clamp_(clip_min, clip_max).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x_natural + delta, requires_grad=False)
    else:
        x_adv = torch.clamp(x_adv, clip_min, clip_max)
    model.train()

    x_adv = Variable(torch.clamp(x_adv, clip_min, clip_max),
                     requires_grad=False)
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    logits = model(x_natural)
    loss_natural = natural_criterion(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(
        F.log_softmax(model(x_adv), dim=1), F.softmax(model(x_natural), dim=1))
    loss = loss_natural + beta * loss_robust
    return loss
Example #26
def main():
    data_holder, task2id, id2task, num_feat, num_voc, num_char, tgt_dict, embeddings = DataLoader.multitask_dataloader(
        pkl_path, num_task=num_task, batch_size=BATCH_SIZE)
    para = model_para
    task2label = {"conll2000": "chunk", "unidep": "POS", "conll2003": "NER"}
    #task2label = {"conll2000": "chunk", "wsjpos": "POS", "conll2003": "NER"}
    #logger = Logger('./logs/'+str(args.gpu))
    para["id2task"] = id2task
    para["n_feats"] = num_feat
    para["n_vocs"] = num_voc
    para["n_tasks"] = num_task
    para["out_size"] = [
        len(tgt_dict[task2label[id2task[ids]]]) for ids in range(num_task)
    ]
    para["n_chars"] = num_char
    model = Model_crf2.build_model_cnn(para)
    model.Word_embeddings.apply_weights(embeddings)

    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    num_params = sum(p.numel() for p in model.parameters())
    print(model)
    print("Num of paras:", num_params)
    print(model.concat_flag)

    def lr_decay(optimizer, epoch, decay_rate=0.9, init_lr=0.015):
        lr = init_lr / (1 + decay_rate * epoch)
        print(" Learning rate is set as:", lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer

    def exp_lr_decay(optimizer, epoch, decay_rate=0.05, init_lr=0.015):
        lr = init_lr * decay_rate**epoch
        print(" Learning rate is set as:", lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer

    if args.optim == "noam":
        model_optim = optim_custorm.NoamOpt(
            para["d_hid"], 1, 1000,
            torch.optim.Adam(params,
                             lr=0.0,
                             betas=(0.9, 0.98),
                             eps=1e-9,
                             weight_decay=L2))
        args.decay = None
    elif args.optim == "sgd":
        model_optim = optim.SGD(params,
                                lr=0.015,
                                momentum=args.momentum,
                                weight_decay=1e-8)

    if args.mode == "train":
        best_F1 = 0
        if not para["crf"]:
            calculate_loss = nn.NLLLoss()
        else:
            calculate_loss = [
                CRFLoss_vb(
                    len(tgt_dict[task2label[id2task[idx]]]) + 2,
                    len(tgt_dict[task2label[id2task[idx]]]),
                    len(tgt_dict[task2label[id2task[idx]]]) + 1)
                for idx in range(num_task)
            ]
            if USE_CUDA:
                for x in calculate_loss:
                    x = x.cuda()
        print("Start training...")
        print('-' * 60)
        KLLoss = nn.KLDivLoss()
        start_point = time.time()
        for epoch_idx in range(NUM_EPOCH):
            if args.decay == "exp":
                model_optim = exp_lr_decay(model_optim, epoch_idx)
            elif args.decay == "normal":
                model_optim = lr_decay(model_optim, epoch_idx)
            Pre, Rec, F1, loss_list = run_epoch(model, data_holder,
                                                model_optim, calculate_loss,
                                                KLLoss, para, epoch_idx,
                                                id2task)

            use_time = time.time() - start_point

            if num_task == 3:

                if loss_list[task2id["conll2003"]] < min(
                        loss_list[task2id["conll2000"]],
                        loss_list[task2id["unidep"]]):
                    args.weight = args.weight * max(
                        loss_list[task2id["conll2000"]],
                        loss_list[task2id["unidep"]]) / loss_list[
                            task2id["conll2003"]]
            print("Change weight to %f at epoch_idx %d:" %
                  (args.weight, epoch_idx))
            print("Time using: %f mins" % (use_time / 60))
            if not best_F1 or best_F1 < F1:
                best_F1 = F1
                Model_crf2.save_model(model_path, model, para)
                print('*' * 60)
                print(
                    "Save model with average Pre: %f, Rec: %f, F1: %f on dev set."
                    % (Pre, Rec, F1))
                save_idx = epoch_idx
                print('*' * 60)
        print("save model at epoch:", save_idx)

    else:
        para_path = os.path.join(path, 'para.pkl')
        with open(para_path, "wb") as f:
            para_save = pickle.load(f)
        model = Model_crf2.build_model(para_save)
        model = Model_crf2.read_model(model_path, model)
        prec_list, rec_list, f1_list = infer(model, data_holder, "test")
Example #27
    def __init__(self, option_map):
        self.policy_loss = nn.KLDivLoss().cuda()
        self.value_loss = nn.MSELoss().cuda()
        self.logger = _logger_factory.makeLogger(
            'elfgames.go.MCTSPrediction-', '')
        self.timer = RLTimer()
Example #28
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.dataset)
            self.num_valid = len(self.valid_loader.dataset)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = config.num_classes

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.weight_decay = config.weight_decay
        self.nesterov = config.nesterov
        self.gamma = config.gamma
        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.model_name = config.save_name

        self.model_num = config.model_num
        self.models = []
        self.optimizers = []
        self.schedulers = []

        self.loss_kl = nn.KLDivLoss(reduction='batchmean')
        self.loss_ce = nn.CrossEntropyLoss()
        self.best_valid_accs = [0.] * self.model_num

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        for i in range(self.model_num):
            # build models
            model = resnet32()
            if self.use_gpu:
                model.cuda()

            self.models.append(model)

            # initialize optimizer and scheduler
            optimizer = optim.SGD(model.parameters(),
                                  lr=self.lr,
                                  momentum=self.momentum,
                                  weight_decay=self.weight_decay,
                                  nesterov=self.nesterov)

            self.optimizers.append(optimizer)

            # set learning rate decay
            scheduler = optim.lr_scheduler.StepLR(self.optimizers[i],
                                                  step_size=60,
                                                  gamma=self.gamma,
                                                  last_epoch=-1)
            self.schedulers.append(scheduler)

        print('[*] Number of parameters of one model: {:,}'.format(
            sum([p.data.nelement() for p in self.models[0].parameters()])))
Example #29
def robust_train_one_epoch(model,
                           optimizer,
                           loader_train,
                           args,
                           eps,
                           delta,
                           epoch,
                           training_output_dir_name,
                           verbose=True):
    print('Current eps: {}, delta: {}'.format(eps, delta))
    losses = []
    losses_ben = []
    model.train()
    if 'hybrid' in args.attack:
        training_time = True
        trainset, testset, data_details = load_dataset_tensor(
            args, data_dir='data', training_time=training_time)
        print('Data loaded for hybrid attack of len {}'.format(len(trainset)))
    if 'KL' in args.loss_fn:
        opt_prob_dir = 'graph_data/optimal_probs/'
        opt_fname = 'logloss_' + str(class_1) + '_' + str(class_2) + '_' + str(
            args.num_samples
        ) + '_' + args.dataset_in + '_' + args.norm + '_' + str(
            args.epsilon) + '.txt'
        optimal_scores_overall = np.loadtxt(opt_prob_dir + opt_fname)
    for t, (x, y, idx, ez, m) in enumerate(loader_train):
        x = x.cuda()
        y = y.cuda()
        if args.loss_fn == 'trades':
            loss, loss_ben, loss_adv = trades_loss(model,
                                                   x,
                                                   y,
                                                   optimizer,
                                                   delta,
                                                   eps,
                                                   args.attack_iter,
                                                   args.gamma,
                                                   beta=args.beta,
                                                   distance=args.attack)
            losses.append(loss_adv.data.cpu().numpy())
            losses_ben.append(loss_ben.data.cpu().numpy())
        elif args.loss_fn == 'KL_emp':
            _, loss_ben, loss_adv = trades_loss(model,
                                                x,
                                                y,
                                                optimizer,
                                                delta,
                                                eps,
                                                args.attack_iter,
                                                args.gamma,
                                                beta=1.0,
                                                distance=args.attack)
            losses.append(loss_adv.data.cpu().numpy())
            losses_ben.append(loss_ben.data.cpu().numpy())
            loss = loss_adv
        else:
            x_mod = None
            if 'KL' in args.loss_fn:
                optimal_scores = torch.from_numpy(
                    optimal_scores_overall[idx]).float().cuda()
            else:
                optimal_scores = None
            if 'hybrid' in args.attack:
                # Find matched and unmatched data and labels
                unmatched_x = x[ez]
                unmatched_y = y[ez]
                matched_x = x[~ez]
                matched_y = y[~ez]
                if 'seed' in args.attack:
                    if len(m[~ez]) > 0:
                        # print('Performing hybrid attack')
                        x_mod = hybrid_attack(matched_x, ez, m, trainset, eps)
                elif 'replace' in args.attack:
                    # Only construct adv. examples for unmatched
                    x = x[ez]
                    y = y[ez]
            x_var = Variable(x, requires_grad=True)
            y_var = Variable(y, requires_grad=False)
            if args.targeted:
                y_target = generate_target_label_tensor(y_var.cpu(),
                                                        args).cuda()
            else:
                y_target = y_var
            if 'PGD_linf' in args.attack:
                adv_x = pgd_attack(model, x, x_var, y_target, args.attack_iter,
                                   eps, delta, args.clip_min, args.clip_max,
                                   args.targeted, args.rand_init)
            elif 'PGD_l2' in args.attack:
                adv_x = pgd_l2_attack(model, x, x_var, y_target,
                                      args.attack_iter, eps, delta,
                                      args.clip_min, args.clip_max,
                                      args.targeted, args.rand_init,
                                      args.num_restarts, x_mod, ez,
                                      args.attack_loss, optimal_scores)
            if 'hybrid' in args.attack:
                x = torch.cat((unmatched_x, matched_x))
                y = torch.cat((unmatched_y, matched_y))
                y_var = Variable(y, requires_grad=False)
                if 'replace' in args.attack:
                    x_mod = hybrid_attack(matched_x, ez, m, rel_data,
                                          args.new_epsilon)
                    adv_x = torch.cat((adv_x, x_mod))
            scores = model(adv_x)
            ben_loss_function = nn.CrossEntropyLoss(reduction='none')
            batch_loss_ben = ben_loss_function(model(x), y_var)
            if args.loss_fn == 'CE':
                loss_function = nn.CrossEntropyLoss(reduction='none')
                batch_loss_adv = loss_function(scores, y_var)
                loss = torch.mean(batch_loss_adv)
            elif args.loss_fn == 'KL':
                # optimal_scores = torch.from_numpy(optimal_scores_overall[idx]).float().cuda()
                loss_function = nn.KLDivLoss(reduction='none')
                batch_loss_adv = loss_function(scores, optimal_scores)
                loss = torch.mean(batch_loss_adv)
                # print(loss.shape)
            elif args.loss_fn == 'KL_flat':
                # optimal_scores = torch.from_numpy(optimal_scores_overall[idx]).float().cuda()
                batch_loss_adv = KL_loss_flat(scores, optimal_scores, y_var, t)
                loss = torch.mean(batch_loss_adv)
            elif args.loss_fn == 'trades_opt':
                adv_loss_function = nn.KLDivLoss(reduction='none')
                ben_loss_function = nn.CrossEntropyLoss(reduction='none')
                batch_loss_adv = adv_loss_function(scores, optimal_scores)
                batch_loss_ben = ben_loss_function(model(x), optimal_scores)
                loss = torch.mean(batch_loss_ben + batch_loss_adv)
            loss_ben = torch.mean(batch_loss_ben)
            losses_ben.append(loss_ben.data.cpu().numpy())
            losses.append(loss.data.cpu().numpy())
        # GD step
        optimizer.zero_grad()
        loss.backward()
        # print(model.conv1.weight.grad)
        optimizer.step()
        if verbose:
            print('loss = %.8f' % (loss.data))
    return np.mean(losses), np.mean(losses_ben)
def get_images(net,
               net_name,
               bs=256,
               i_time=0,
               epochs=1000,
               idx=-1,
               var_scale=0.00005,
               net_student=None,
               prefix=None,
               competitive_scale=0.00,
               train_writer=None,
               global_iteration=None,
               use_amp=False,
               optimizer=None,
               inputs=None,
               bn_reg_scale=0.0,
               random_labels=False,
               l2_coeff=0.0):
    '''
    Function returns inverted images from the pretrained model; parameters are tied to the CIFAR dataset
    args in:
        net: network to be inverted
        bs: batch size
        epochs: total number of iterations to generate inverted images, training longer helps a lot!
        idx: an external flag for printing purposes: only print in the first round, set as -1 to disable
        var_scale: the scaling factor for variance loss regularization; this may vary depending on bs
            (larger gives more blurred images but less noise)
        net_student: model to be used for Adaptive DeepInversion
        prefix: defines the path to store images
        competitive_scale: coefficient for Adaptive DeepInversion
        train_writer: tensorboardX object to store intermediate losses
        global_iteration: indexer to be used for tensorboard
        use_amp: boolean to indicate usage of APEX AMP for FP16 calculations - roughly twice as fast and uses less memory on Tensor Cores
        optimizer: optimizer to be used for model inversion
        inputs: data place holder for optimization, will be reinitialized to noise
        bn_reg_scale: weight for r_feature_regularization
        random_labels: sample labels from random distribution or use columns of the same class
        l2_coeff: coefficient for L2 loss on input
    return:
        A tensor on GPU with shape (bs, 3, 32, 32) for CIFAR
    '''

    kl_loss = nn.KLDivLoss(reduction='batchmean').cuda()

    # preventing backpropagation through student for Adaptive DeepInversion
    # net_student.eval()

    best_cost = 1e6

    # initialize gaussian inputs
    inputs.data = torch.randn((bs, 3, 32, 32),
                              requires_grad=True,
                              device='cuda')
    # if use_amp:
    #     inputs.data = inputs.data.half()

    # set up criteria for optimization
    criterion = nn.CrossEntropyLoss()

    optimizer.state = collections.defaultdict(dict)  # Reset state of optimizer

    # target outputs to generate
    if random_labels:
        targets = torch.LongTensor([random.randint(0, 9)
                                    for _ in range(bs)]).to('cuda')
    else:
        targets = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 25 +
                                   [0, 1, 2, 3, 4, 5]).to('cuda')

    ## Create hooks for feature statistics catching
    loss_r_feature_layers = []
    for module in net.modules():
        if isinstance(module, nn.BatchNorm2d):
            loss_r_feature_layers.append(DeepInversionFeatureHook(module))
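    # each DeepInversionFeatureHook is expected to expose an `r_feature` term that
    # measures how far the batch statistics of its BatchNorm input are from the
    # layer's stored running mean/variance (accumulated below, weighted by bn_reg_scale)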

    # setting up the range for jitter
    lim_0, lim_1 = 2, 2
    print('number of BN layers = %d' % len(loss_r_feature_layers))

    name_use = "./best_images"
    os.makedirs(name_use, exist_ok=True)
    os.makedirs("./training", exist_ok=True)
    start_time = time.time()
    for epoch in range(epochs):
        # apply random jitter offsets
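        # (torch.roll circularly shifts the image by up to +/-2 pixels in height and
        #  width - a cheap augmentation that discourages pixel-level artifacts)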
        off1 = random.randint(-lim_0, lim_0)
        off2 = random.randint(-lim_1, lim_1)
        inputs_jit = torch.roll(inputs, shifts=(off1, off2), dims=(2, 3))

        # forward with jittered images
        optimizer.zero_grad()
        net.zero_grad()
        outputs = net(inputs_jit)
        loss = criterion(outputs, targets)
        loss_target = loss.item()

        # competition loss, Adaptive DeepInversion
        if competitive_scale != 0.0:
            net_student.zero_grad()
            outputs_student = net_student(inputs_jit)
            T = 3.0

            # Jensen-Shannon divergence: a symmetric alternative to a one-sided KL
            # term between the teacher and student predictions
            P = F.softmax(outputs_student / T, dim=1)
            Q = F.softmax(outputs / T, dim=1)
            M = 0.5 * (P + Q)

            P = torch.clamp(P, 0.01, 0.99)
            Q = torch.clamp(Q, 0.01, 0.99)
            M = torch.clamp(M, 0.01, 0.99)
            eps = 0.0
            loss_verifier_cig = 0.5 * kl_loss(torch.log(P + eps), M) \
                + 0.5 * kl_loss(torch.log(Q + eps), M)
            # JS criterion: 0 means the predictions match, 1 means completely different
            loss_verifier_cig = 1.0 - torch.clamp(loss_verifier_cig, 0.0, 1.0)

            loss = loss + competitive_scale * loss_verifier_cig
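            # minimizing (1 - JS) pushes the images toward inputs on which teacher and
            # student disagree - the "competition" objective of Adaptive DeepInversion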

        # apply total variation regularization
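        # diff1/diff2 are horizontal/vertical neighbour differences, diff3/diff4 the two
        # diagonal ones; penalizing their norms smooths the image, which is why a larger
        # var_scale yields blurrier but less noisy results (see the docstring)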
        diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
        diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
        diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:]
        diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:]
        loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(
            diff3) + torch.norm(diff4)
        loss = loss + var_scale * loss_var

        # R_feature loss
        loss_distr = sum([mod.r_feature for mod in loss_r_feature_layers])
        loss = loss + bn_reg_scale * loss_distr  # best for noise before BN

        # l2 loss
        loss = loss + l2_coeff * torch.norm(inputs_jit, 2)
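        # l2_coeff defaults to 0.0 in the signature, so this penalty is off unless set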

        if debug_output and epoch % 200 == 0:
            print(
                f"It {epoch}\t Losses: total: {loss.item():3.3f},\ttarget: {loss_target:3.3f} \tR_feature_loss unscaled:\t {loss_distr.item():3.3f}"
            )
        if epoch % 100 == 0 and epoch > 0:
            # make sure the per-epoch folder exists before saving individual images
            save_dir = './training/{}/{}'.format(net_name, epoch)
            os.makedirs(save_dir, exist_ok=True)
            for ids in range(bs):
                save_image(inputs.data[ids].clone(),
                           '{}/{}.png'.format(save_dir, bs * i_time + ids))

        if best_cost > loss.item():
            best_cost = loss.item()
            best_inputs = inputs.data.clone()  # clone so later optimizer steps don't overwrite the best snapshot

        # backward pass
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()
    print("--- %s seconds ---" % (time.time() - start_time))
    outputs = net(best_inputs)
    _, predicted_teach = outputs.max(1)

    # outputs_student=net_student(best_inputs)
    # _, predicted_std = outputs_student.max(1)

    if idx == 0:
        print('Teacher correct out of {}: {}, loss at {}'.format(
            bs,
            predicted_teach.eq(targets).sum().item(),
            criterion(outputs, targets).item()))
        # print('Student correct out of {}: {}, loss at {}'.format(bs, predicted_std.eq(targets).sum().item(), criterion(outputs_student, targets).item()))

    # if prefix is not None:
    #     name_use = prefix + name_use
    # next_batch = len(glob.glob("./%s/*.png" % name_use)) // 1

    vutils.save_image(best_inputs[:20].clone(),
                      './{}/{}_output.png'.format(name_use, net_name),
                      normalize=True,
                      scale_each=True,
                      nrow=10)

    if train_writer is not None:
        train_writer.add_scalar('gener_teacher_criteria',
                                criterion(outputs, targets), global_iteration)
        # train_writer.add_scalar('gener_student_criteria', criterion(outputs_student, targets), global_iteration)

        train_writer.add_scalar('gener_teacher_acc',
                                predicted_teach.eq(targets).sum().item() / bs,
                                global_iteration)
        # train_writer.add_scalar('gener_student_acc', predicted_std.eq(targets).sum().item() / bs, global_iteration)

        train_writer.add_scalar('gener_loss_total', loss.item(),
                                global_iteration)
        train_writer.add_scalar('gener_loss_var',
                                (var_scale * loss_var).item(),
                                global_iteration)

    # net_student.train()

    return best_inputs
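

# Usage sketch (illustrative only): a minimal example of how get_images might be
# driven, assuming `load_teacher` is a hypothetical helper that returns a pretrained
# CIFAR-10 classifier on the GPU; the hyperparameter values below are placeholders,
# not tuned settings.
#
#   net = load_teacher().cuda().eval()
#   inputs = torch.randn((256, 3, 32, 32), requires_grad=True, device='cuda')
#   optimizer = torch.optim.Adam([inputs], lr=0.1)
#   best_inputs = get_images(net, net_name='teacher', bs=256, epochs=2000, idx=0,
#                            optimizer=optimizer, inputs=inputs,
#                            var_scale=2.5e-5, bn_reg_scale=10.0, l2_coeff=0.0)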