Пример #1
0
def cutmix_pass(net, criterion, inputs, targets, alpha, max_lambda=True):
    """Compute a CutMix-weighted loss for one batch.

    A coefficient ``lam`` is drawn from Beta(alpha, alpha) and moved to
    ``inputs.device``. When ``max_lambda`` is set it is replaced by
    ``min(lam, 1 - lam)``, so it is at most 0.5, before being passed to
    ``cutmix``. ``cutmix`` returns the mixed batch, two target tensors and a
    (possibly adjusted) ``lam``. The network runs once on the mixed batch, and
    the loss is
    ``lam * criterion(outputs, targets) + (1 - lam) * criterion(outputs, targets_b)``.

    Returns:
        A ``(loss, outputs)`` tuple: the combined loss and the raw network
        outputs on the mixed batch.
    """
    lam = beta.Beta(alpha, alpha).sample().to(inputs.device)
    if max_lambda:
        # NOTE(review): despite the flag name this keeps the *smaller* of
        # lam / 1 - lam (mixup_pass uses torch.max); confirm this is intended.
        lam = torch.min(lam, 1 - lam)

    mixed_batch, _targets_a, targets_b, lam = cutmix(inputs, targets, lam)

    outputs = net(mixed_batch)
    # NOTE(review): the first term uses the caller's ``targets`` rather than the
    # ``targets_a`` returned by cutmix; equivalent only if cutmix returns them
    # unchanged — verify against cutmix's implementation.
    return (
        lam * criterion(outputs, targets)
        + (1 - lam) * criterion(outputs, targets_b),
        outputs,
    )
Пример #2
0
def mixup_pass(net, criterion, inputs, targets, alpha, max_lambda=False):
    """Compute a blended clean + mixup loss for one batch.

    A coefficient ``lam`` is drawn from Beta(alpha, alpha) and moved to
    ``inputs.device``. When ``max_lambda`` is set it becomes
    ``max(lam, 1 - lam)``, so at least 0.5, which gives the clean-batch term
    at least half the weight. ``mixup`` produces the mixed batch and its
    targets. The network runs twice: on the untouched ``inputs`` and on the
    mixed batch.

    Returns:
        ``lam * criterion(clean_outputs, targets)
        + (1 - lam) * criterion(mixed_outputs, shuffled_targets)``.
        Unlike ``cutmix_pass``, the network outputs are not returned.
    """
    lam = beta.Beta(alpha, alpha).sample().to(inputs.device)
    if max_lambda:
        lam = torch.max(lam, 1 - lam)

    mixed_inputs, shuffled_targets = mixup(inputs, targets, lam)

    clean_outputs = net(inputs)
    mixed_outputs = net(mixed_inputs)

    clean_term = criterion(clean_outputs, targets)
    mixed_term = criterion(mixed_outputs, shuffled_targets)
    return lam * clean_term + (1 - lam) * mixed_term
Пример #3
0
    def __init__(self, args):
        """Set up the experiment: config, model, optimizer, scheduler, data and logging.

        Args:
            args: mapping of run options. It is passed to
                ``common.init_experiment``, and its ``'load'`` entry, when not
                None, is given to ``self.load_model`` at the end of setup.

        Note: ``self.warmup_rate`` is only defined when the scheduler reports
        a positive number of warmup epochs.
        """
        self.args = args
        (self.config, self.output_dir, self.logger,
         self.device) = common.init_experiment(args)
        cfg = self.config

        # --- Model, optimizer and scheduler ---
        model_name = cfg['model']['name']
        assert model_name in NETWORKS, f"Unrecognized model name {model_name}"
        build_network = NETWORKS[model_name]['net']
        self.model = build_network(
            pretrained=cfg['model']['pretrained']).to(self.device)
        self.optim = train_utils.get_optimizer(cfg['optimizer'],
                                               self.model.parameters())
        # The scheduler factory also needs the total epoch count, so merge it
        # into a copy of the scheduler section.
        scheduler_cfg = {**cfg['scheduler'], "epochs": cfg["epochs"]}
        self.scheduler, self.warmup_epochs = train_utils.get_scheduler(
            scheduler_cfg, self.optim)

        if self.warmup_epochs > 0:
            # Per-epoch LR increment from ~0 (1e-12) up to the configured LR.
            target_lr = cfg['optimizer']['lr']
            self.warmup_rate = (target_lr - 1e-12) / self.warmup_epochs

        # --- Data ---
        data_cfg = cfg['data']
        loaders = data_utils.get_dataloaders(
            train_root=data_cfg['train_root'],
            test_root=data_cfg['test_root'],
            transforms=data_cfg['transforms'],
            val_split=data_cfg['val_split'],
            batch_size=data_cfg['batch_size'])
        self.train_loader, self.val_loader, self.test_loader = loaders
        # Symmetric Beta(alpha, alpha); alpha defaults to 0.3 when not configured.
        mix_alpha = data_cfg.get("alpha", 0.3)
        self.beta_dist = beta.Beta(mix_alpha, mix_alpha)
        self.batch_size = data_cfg['batch_size']

        # --- Loss and bookkeeping ---
        self.criterion = losses.LogLoss()
        self.best_val_loss = np.inf
        self.done_epochs = 0

        # --- Weights & Biases run ---
        run = wandb.init(project='deepfake-dl-hack')
        self.logger.write(f"Wandb: {run.get_url()}", mode='info')

        # --- Optional checkpoint restore ---
        checkpoint = args['load']
        if checkpoint is not None:
            self.load_model(checkpoint)
Пример #4
0
def _loss_value(value):
    """Return ``value`` as a Python number.

    Tensors and NumPy scalars go through ``.item()``. Plain numbers, such as
    the ``0.0`` placeholder used when mixing is disabled, are returned via
    ``float``, so a non-zero float loss is reported as its real value rather
    than as 0.0.
    """
    return value.item() if hasattr(value, "item") else float(value)


def _mix_loss(network, criterionB, img, index, memory_bank, beta_dist, device):
    """Blend each image with a random batch partner and return ``criterionB``'s loss."""
    batch_size = index.size(0)
    permutation = np.arange(batch_size)
    np.random.shuffle(permutation)
    img_b = img[permutation]
    index_b = index[permutation]
    # One mixing weight per sample. view(-1, 1, 1, 1) broadcasts it over the
    # remaining image dims (assumes N x C x H x W input — TODO confirm).
    alphas = beta_dist.sample([batch_size]).to(device)
    mixed = img * alphas.view(-1, 1, 1, 1) + img_b * (1 - alphas).view(-1, 1, 1, 1)
    return criterionB(network(mixed), alphas, index, index_b, memory_bank)


def _update_consistency(index, NNS, IPS, all_targets, ipacc, nnacc):
    """Accumulate IPS / NNS label-consistency diagnostics for one batch."""
    index_np = index.detach().cpu().numpy()
    nns_np = NNS.detach().cpu().numpy()
    ips_np = IPS.detach().cpu().numpy()
    for sample_idx, nns_row, ips_row in zip(index_np, nns_np, ips_np):
        right_target = all_targets[sample_idx]
        this_ips = np.unique(ips_row)
        n_ips = len(this_ips)
        ip_consistency = (all_targets[this_ips] == right_target).sum() / float(n_ips)
        ipacc.update(ip_consistency, n_ips)

        # Compare as many nearest neighbours as there are IPS members. The +1
        # here and the -1 below discount one self-match (presumably the sample
        # itself is NNS[0] — TODO confirm).
        this_nns = nns_row[:n_ips + 1]
        nn_consistency = ((all_targets[this_nns] == right_target).sum() - 1) / float(n_ips)
        nnacc.update(nn_consistency, n_ips)


def train(i_epoch, network, criterionA, criterionB, optimizer, dataloader,
          device, memory_bank, ramp_up, consistency_prob=0.0):
    """Run one training epoch with instance, invariance and optional mix losses.

    For every batch ``data`` yielded by ``dataloader`` (``data[0]`` = sample
    indices, ``data[1]`` = images):
      * ``criterionA`` returns ``L_ins``, ``L_inv`` and the neighbour sets
        ``NNS`` / ``IPS``;
      * if the module-level ``args.mix`` is set, images are mixed with a
        shuffled copy of the batch using per-sample weights drawn from
        Beta(0.75, 0.75), and ``criterionB`` returns ``L_mix``;
      * ``L = L_ins + args.lam_inv * ramp_up(i_epoch) * L_inv + args.lam_mix * L_mix``
        is back-propagated, and ``memory_bank`` is updated with the detached
        embeddings.
    Epoch averages are written to the module-level ``writer`` and to ``logging``.

    Args:
        i_epoch: current epoch number, used for ``ramp_up`` and logging.
        network: model being trained.
        criterionA: returns ``(L_ins, L_inv, NNS, IPS)``.
        criterionB: mix loss, called as
            ``criterionB(outputs, alphas, index, indexB, memory_bank)``.
        optimizer: optimizer stepped once per batch.
        dataloader: iterable of batches (see above).
        device: device that images, indices and mixing weights are moved to.
        memory_bank: object exposing ``update_points(embeddings, index)``.
        ramp_up: callable mapping the epoch to the invariance-loss multiplier.
        consistency_prob: per-batch probability of computing the diagnostic
            IPS/NNS label-consistency statistics. Defaults to 0.0 (never),
            matching the previous hard-coded behaviour. When enabled, this
            requires ``dataloader.dataset.targets``.
    """
    network.train()
    losses_ins = AvgMeter()
    losses_inv = AvgMeter()
    losses_mix = AvgMeter()
    losses = AvgMeter()
    beta_dist = beta.Beta(0.75, 0.75)
    # NOTE(review): ipacc / nnacc are accumulated but never reported.
    ipacc = AverageMeter()
    nnacc = AverageMeter()
    # Labels are only needed by the diagnostic branch; load them on first use.
    all_targets = None
    # ramp_up depends only on the epoch, so the invariance weight is fixed for
    # the whole epoch.
    inv_weight = args.lam_inv * ramp_up(i_epoch)

    pbar = tqdm(dataloader)
    for data in pbar:
        img = data[1].to(device)
        index = data[0].to(device)
        output = network(img).to(device)

        # NNS: nearest-neighbour set; IPS: invariance-propagation set.
        L_ins, L_inv, NNS, IPS = criterionA(output, index, memory_bank)

        # The draw happens every batch, even with consistency_prob == 0.0, so
        # the NumPy RNG stream (also used by the mixing shuffle) is unchanged.
        if np.random.rand() < consistency_prob:
            if all_targets is None:
                all_targets = np.array(dataloader.dataset.targets)
            _update_consistency(index, NNS, IPS, all_targets, ipacc, nnacc)

        if args.mix:
            L_mix = _mix_loss(network, criterionB, img, index, memory_bank,
                              beta_dist, device)
        else:
            L_mix = 0.0

        L = L_ins + inv_weight * L_inv + args.lam_mix * L_mix
        losses_ins.add(_loss_value(L_ins))
        losses_inv.add(_loss_value(L_inv))
        losses_mix.add(_loss_value(L_mix))
        losses.add(_loss_value(L))

        optimizer.zero_grad()
        L.backward()
        optimizer.step()

        with torch.no_grad():
            memory_bank.update_points(output.detach(), index)

        lr = optimizer.param_groups[0]['lr']
        pbar.set_description("Epoch:{} [lr:{}]".format(i_epoch, lr))
        info = 'L: {:.4f} = L_ins: {:.4f} + {:.3f} * L_inv: {:.4f} + {:.3f} * L_mix: {:.4f}'.format(
            losses.get(), losses_ins.get(), inv_weight,
            losses_inv.get(), args.lam_mix, losses_mix.get())
        pbar.set_postfix(info=info)

    epoch_meters = (('L', losses), ('L_ins', losses_ins),
                    ('L_inv', losses_inv), ('L_mix', losses_mix))
    for name, meter in epoch_meters:
        writer.add_scalar(name, meter.get(), i_epoch)
    for name, meter in epoch_meters:
        logging.info('Epoch {}: {}: {:.4f}'.format(i_epoch, name, meter.get()))
Пример #5
0
def beta_loss(x, alpha, beta1):
    """Return the log-density of ``x`` under Beta(alpha, beta1).

    NOTE(review): despite the name this is the log-likelihood (higher is
    better), not a negated loss. Callers that minimise it must negate it;
    confirm at the call sites.
    """
    return b1.Beta(alpha, beta1).log_prob(x)
Пример #6
0
def sample_beta_dist(alpha, beta):
    """Draw one sample from a Beta(alpha, beta) distribution.

    Note that the parameter ``beta`` is the second shape parameter. The
    distribution class comes from the ``beta_dist`` module alias.
    """
    distribution = beta_dist.Beta(alpha, beta)
    draw = distribution.sample()
    return draw
Пример #7
0
 def __init__(self, betaparam=0.2):
     """Store a symmetric Beta(betaparam, betaparam) distribution as ``self.betadist``."""
     concentration = betaparam
     self.betadist = beta.Beta(concentration, concentration)