Example #1: single-network variant. The unlabeled batch is drawn from a standalone iterator, and only the labeled half of the mixed batch contributes to the loss.
    def train_model(self, model, optim, batch_data, params: DivideMixParams, meter: Meter):
        model.train()

        # in this variant batch_data carries only the labeled batch;
        # the unlabeled batch is drawn from a separate iterator below
        sup = batch_data
        (xs, xs2, nys, prob) = sup

        # lazily create the unlabeled iterator and restart it once exhausted
        try:
            if self.unlabeled_dataloader_iter is None:
                self.unlabeled_dataloader_iter = iter(self.unlabeled_dataloader)
            unsup = next(self.unlabeled_dataloader_iter)
        except StopIteration:
            self.unlabeled_dataloader_iter = iter(self.unlabeled_dataloader)
            unsup = next(self.unlabeled_dataloader_iter)

        (uxs, uxs2, unys) = unsup
        (uxs, uxs2, unys) = (uxs.to(self.device), uxs2.to(self.device), unys.to(self.device))

        # one-hot encode the (noisy) labels
        n_targets = tricks.onehot(nys, params.n_classes)
        # per-sample probability that the given label is clean, used for label refinement
        prob = prob.view(-1, 1).float()
        batch_size = xs.shape[0]
        with torch.no_grad():
            # label co-guessing of unlabeled samples
            outputs_u11 = model(uxs)
            outputs_u12 = model(uxs2)

            pu = self.label_guesses_(outputs_u11, outputs_u12)
            targets_u = self.sharpen_(pu, params.T)  # temperature sharpening

            # label refinement of labeled samples
            outputs_x = model(xs)
            outputs_x2 = model(xs2)

            px = self.label_guesses_(outputs_x, outputs_x2)
            px = prob * n_targets + (1 - prob) * px
            targets_x = self.sharpen_(px, params.T)  # temperature sharpening

            self.acc_precise_(outputs_x.argmax(dim=-1), nys, meter=meter)
            self.acc_precise_(outputs_u11.argmax(dim=-1), unys, meter=meter, name='uacc')

        # MixMatch: mix labeled and unlabeled samples with a Beta-sampled ratio
        l = np.random.beta(params.mix_beta, params.mix_beta)
        l = max(l, 1 - l)  # keep each mixed sample dominated by its original input

        all_inputs = torch.cat([xs, xs2, uxs, uxs2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

        idx = torch.randperm(all_inputs.shape[0], device=self.device)

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        # this variant only mixes and trains on the labeled half of the batch
        mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
        mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]

        logits = model(mixed_input)
        logits_x = logits

        meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits_x, mixed_target[:batch_size * 2],
                                                             meter=meter, name='Lx')

        # regularization: KL(uniform prior || mean prediction), which discourages
        # the averaged prediction from collapsing onto a few classes
        prior = torch.ones(params.n_classes, device=self.device) / params.n_classes
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        meter.Lpen = torch.sum(prior * torch.log(prior / pred_mean))

        meter.Lall = meter.Lall + meter.Lpen

        # compute gradient and do SGD step
        optim.zero_grad()
        meter.Lall.backward()
        optim.step()

        return meter
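
Both examples lean on helpers that are not defined in this snippet: label_guesses_, sharpen_, and loss_ce_with_targets_. Below is a minimal sketch of their likely semantics, assuming they follow the standard MixMatch/DivideMix recipe; the actual methods are bound to the trainer class (and loss_ce_with_targets_ additionally logs to the meter), so these signatures are simplified.

    # Hypothetical sketches, NOT the actual implementations from this codebase.
    import torch

    def label_guesses_(*logits):
        # co-guessing: average the softmax predictions over several
        # augmented views and/or networks
        probs = [torch.softmax(l, dim=1) for l in logits]
        return torch.stack(probs, dim=0).mean(dim=0)

    def sharpen_(p, T):
        # temperature sharpening: p ** (1/T), renormalized to sum to 1;
        # T < 1 pushes each row toward a one-hot distribution
        p = p ** (1 / T)
        return p / p.sum(dim=1, keepdim=True)

    def loss_ce_with_targets_(logits, targets):
        # cross entropy against soft targets
        return -(targets * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
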
Example #2: two-network (co-training) variant. The peer network joins the label co-guessing, and both the labeled (Lx) and unlabeled (Lu) halves of the mixed batch are trained.
    def train_model(self, model, model2, optim, batch_data,
                    params: DivideMixParams, meter: Meter):
        model.train()
        model2.eval()  # the peer network only provides guesses and is not updated here

        sup, unsup = batch_data
        (xs, xs2, ys, nys, prob) = sup
        (uxs, uxs2, uys, unys) = unsup

        # one-hot encode the (noisy) labels
        n_targets = tricks.onehot(nys, params.n_classes)
        # per-sample probability that the given label is clean, used for label refinement
        prob = prob.view(-1, 1).float()
        batch_size = xs.shape[0]
        with torch.no_grad():
            # label co-guessing of unlabeled samples
            outputs_u11 = model(uxs)
            outputs_u12 = model(uxs2)
            outputs_u21 = model2(uxs)
            outputs_u22 = model2(uxs2)

            pu = self.label_guesses_(outputs_u11, outputs_u12, outputs_u21,
                                     outputs_u22)
            targets_u = self.sharpen_(pu, params.T)  # temperature sharpening

            # label refinement of labeled samples
            outputs_x = model(xs)
            outputs_x2 = model(xs2)

            px = self.label_guesses_(outputs_x, outputs_x2)
            px = prob * n_targets + (1 - prob) * px
            targets_x = self.sharpen_(px, params.T)  # temperature sharpening

        # MixMatch: mix labeled and unlabeled samples with a Beta-sampled ratio
        l = np.random.beta(params.mix_beta, params.mix_beta)
        l = max(l, 1 - l)  # keep each mixed sample dominated by its original input

        all_inputs = torch.cat([xs, xs2, uxs, uxs2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u],
                                dim=0)

        idx = torch.randperm(all_inputs.shape[0], device=self.device)

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        logits = model(mixed_input)
        logits_x = logits[:batch_size * 2]
        logits_u = logits[batch_size * 2:]

        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            logits_x, mixed_target[:batch_size * 2], meter=meter, name='Lx')
        meter.Lall = meter.Lall + self.loss_mse_(
            logits_u,
            mixed_target[batch_size * 2:],
            # the unsupervised weight is ramped up over epochs, capped at 25
            w_mse=params.rampup_sche(params.eidx) * 25,
            meter=meter,
            name='Lu')

        # regularization: KL(uniform prior || mean prediction), which discourages
        # the averaged prediction from collapsing onto a few classes
        prior = torch.ones(params.n_classes, device=self.device) / params.n_classes
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        meter.Lpen = torch.sum(prior * torch.log(prior / pred_mean))

        meter.Lall = meter.Lall + meter.Lpen

        # compute gradient and do SGD step
        optim.zero_grad()
        meter.Lall.backward()
        optim.step()

        return meter
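
The unsupervised weight also depends on params.rampup_sche, which is not defined in this snippet. A plausible stand-in, assuming the exponential ramp-up commonly paired with consistency losses (an assumption, not confirmed by this code):

    # Hypothetical ramp-up schedule; the actual rampup_sche is not shown here.
    import numpy as np

    def make_rampup_sche(rampup_length=16):
        def rampup(eidx):
            # rises from ~0 at epoch 0 to 1.0 after rampup_length epochs
            if rampup_length == 0:
                return 1.0
            t = min(max(eidx / rampup_length, 0.0), 1.0)
            return float(np.exp(-5.0 * (1.0 - t) ** 2))
        return rampup

Under a schedule like this, w_mse grows smoothly toward its cap of 25, so the MSE consistency term only dominates once the co-guessed targets have had time to stabilize.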