def train_model(self, model, optim, batch_data, params: DivideMixParams, meter: Meter):
    """DivideMix step for a single model: co-guess and refine labels, then
    apply MixMatch on the labeled views only (no unsupervised loss term)."""
    model.train()

    sup = batch_data
    (xs, xs2, nys, prob) = sup

    # Cycle through the unlabeled dataloader, restarting it once exhausted.
    try:
        if self.unlabeled_dataloader_iter is None:
            self.unlabeled_dataloader_iter = iter(self.unlabeled_dataloader)
        unsup = next(self.unlabeled_dataloader_iter)
    except StopIteration:
        self.unlabeled_dataloader_iter = iter(self.unlabeled_dataloader)
        unsup = next(self.unlabeled_dataloader_iter)

    (uxs, uxs2, unys) = unsup
    (uxs, uxs2, unys) = (uxs.to(self.device), uxs2.to(self.device), unys.to(self.device))

    n_targets = tricks.onehot(nys, params.n_classes)
    prob = prob.view(-1, 1).float()
    batch_size = xs.shape[0]

    with torch.no_grad():
        # label co-guessing of unlabeled samples
        outputs_u11 = model(uxs)
        outputs_u12 = model(uxs2)

        pu = self.label_guesses_(outputs_u11, outputs_u12)
        targets_u = self.sharpen_(pu, params.T)  # temperature sharpening

        # label refinement of labeled samples
        outputs_x = model(xs)
        outputs_x2 = model(xs2)

        px = self.label_guesses_(outputs_x, outputs_x2)
        px = prob * n_targets + (1 - prob) * px
        targets_x = self.sharpen_(px, params.T)  # temperature sharpening

        self.acc_precise_(outputs_x.argmax(dim=-1), nys, meter=meter)
        self.acc_precise_(outputs_u11.argmax(dim=-1), unys, meter=meter, name='uacc')

    # mixmatch
    l = np.random.beta(params.mix_beta, params.mix_beta)
    l = max(l, 1 - l)

    all_inputs = torch.cat([xs, xs2, uxs, uxs2], dim=0)
    all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

    idx = torch.randperm(all_inputs.shape[0])
    input_a, input_b = all_inputs, all_inputs[idx]
    target_a, target_b = all_targets, all_targets[idx]

    # only the labeled half (two augmentations per sample) is mixed and trained on
    mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
    mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]

    logits = model(mixed_input)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        logits, mixed_target, meter=meter, name='Lx')

    # regularization: penalize drift of the mean prediction from the uniform prior
    prior = torch.ones(params.n_classes, device=self.device) / params.n_classes
    pred_mean = torch.softmax(logits, dim=1).mean(0)
    meter.Lpen = torch.sum(prior * torch.log(prior / pred_mean))
    meter.Lall = meter.Lall + meter.Lpen

    # compute gradient and do SGD step
    optim.zero_grad()
    meter.Lall.backward()
    optim.step()

    return meter
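
# The two helpers used above are not defined in this file. Below is a minimal
# sketch of what they are assumed to do, following the standard
# DivideMix/MixMatch recipe: co-guessing averages the softmax predictions over
# all provided views, and sharpening raises the averaged distribution to the
# power 1/T before renormalizing. The names and signatures mirror the call
# sites; treat them as hypothetical reconstructions, not the repository's
# actual implementation.
def label_guesses_(self, *logits):
    # average the softmax predictions over every augmentation / model view
    probs = [torch.softmax(logit, dim=1) for logit in logits]
    return sum(probs) / len(probs)

def sharpen_(self, p, T=0.5):
    # temperature sharpening: p^(1/T), renormalized to sum to 1 per sample
    sharpened = p ** (1 / T)
    return sharpened / sharpened.sum(dim=1, keepdim=True)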
def train_model(self, model, model2, optim, batch_data, params: DivideMixParams, meter: Meter):
    """DivideMix co-training step: `model` is trained while `model2` (kept in
    eval mode) joins the label co-guessing; both Lx and Lu losses are applied."""
    model.train()
    model2.eval()

    sup, unsup = batch_data
    (xs, xs2, ys, nys, prob) = sup
    (uxs, uxs2, uys, unys) = unsup

    n_targets = tricks.onehot(nys, params.n_classes)
    prob = prob.view(-1, 1).float()
    batch_size = xs.shape[0]

    with torch.no_grad():
        # label co-guessing of unlabeled samples: average both networks'
        # predictions over both augmentations
        outputs_u11 = model(uxs)
        outputs_u12 = model(uxs2)
        outputs_u21 = model2(uxs)
        outputs_u22 = model2(uxs2)

        pu = self.label_guesses_(outputs_u11, outputs_u12, outputs_u21, outputs_u22)
        targets_u = self.sharpen_(pu, params.T)  # temperature sharpening

        # label refinement of labeled samples
        outputs_x = model(xs)
        outputs_x2 = model(xs2)

        px = self.label_guesses_(outputs_x, outputs_x2)
        px = prob * n_targets + (1 - prob) * px
        targets_x = self.sharpen_(px, params.T)  # temperature sharpening

    # mixmatch
    l = np.random.beta(params.mix_beta, params.mix_beta)
    l = max(l, 1 - l)

    all_inputs = torch.cat([xs, xs2, uxs, uxs2], dim=0)
    all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

    idx = torch.randperm(all_inputs.shape[0], device=self.device)
    input_a, input_b = all_inputs, all_inputs[idx]
    target_a, target_b = all_targets, all_targets[idx]

    mixed_input = l * input_a + (1 - l) * input_b
    mixed_target = l * target_a + (1 - l) * target_b

    logits = model(mixed_input)
    logits_x = logits[:batch_size * 2]
    logits_u = logits[batch_size * 2:]

    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        logits_x, mixed_target[:batch_size * 2], meter=meter, name='Lx')
    meter.Lall = meter.Lall + self.loss_mse_(
        logits_u, mixed_target[batch_size * 2:],
        w_mse=params.rampup_sche(params.eidx) * 25,
        meter=meter, name='Lu')

    # regularization: penalize drift of the mean prediction from the uniform prior
    prior = torch.ones(params.n_classes, device=self.device) / params.n_classes
    pred_mean = torch.softmax(logits, dim=1).mean(0)
    meter.Lpen = torch.sum(prior * torch.log(prior / pred_mean))
    meter.Lall = meter.Lall + meter.Lpen

    # compute gradient and do SGD step
    optim.zero_grad()
    meter.Lall.backward()
    optim.step()

    return meter
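
# Neither loss helper is defined here either. A plausible sketch under the
# DivideMix formulation: Lx is cross-entropy against soft (mixed) targets, and
# Lu is a mean-squared-error consistency term between predicted probabilities
# and the guessed targets, scaled by a ramp-up weight. The meter/name
# bookkeeping is omitted; signatures are assumptions taken from the call
# sites above.
def loss_ce_with_targets_(self, logits, targets, meter=None, name=None):
    # soft-target cross-entropy: -sum(targets * log_softmax(logits)), averaged
    return -(targets * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()

def loss_mse_(self, logits, targets, w_mse=1.0, meter=None, name=None):
    # Brier-style consistency loss on the unlabeled mix, weighted by w_mse
    probs = torch.softmax(logits, dim=1)
    return w_mse * torch.mean((probs - targets) ** 2)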