def train_one_epoch_simclr(epoch, model, criteria_z, optim, lr_schdlr, ema,
                           dltrain_f, lambda_s, n_iters, logger, bt, mu):
    """Training function for SimCLR only."""
    model.train()
    loss_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_f = iter(dltrain_f)
    for it in range(n_iters):
        ims_s_weak, ims_s_strong, lbs_s_real = next(dl_f)  # SimCLR augmentations

        imgs = torch.cat([ims_s_weak, ims_s_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu)
        logits, logit_z, _ = model(imgs)
        logits_z = de_interleave(logit_z, 2 * mu)

        # split the projections for the SimCLR loss
        logits_s_w_z, logits_s_s_z = torch.split(logits_z, bt * mu)

        loss_s = criteria_z(logits_s_w_z, logits_s_s_z)
        loss = loss_s

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_simclr_meter.update(loss_s.item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start
            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)
            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. "
                "loss_simclr: {:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_simclr_meter.avg, lr_log, t))
            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_simclr_meter.avg, model
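# NOTE (sketch, not part of the original listing): `interleave`/`de_interleave`
# are not defined here. The functions below follow the common FixMatch-style
# implementation, which reorders the concatenated batch so that labeled and
# unlabeled samples are mixed before the forward pass and restored afterwards;
# the repository's own helpers are assumed to behave the same way.
def interleave(x, size):
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])


def de_interleave(x, size):
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])
# Round trip: de_interleave(interleave(t, k), k) recovers t whenever t.shape[0]
# is a multiple of k.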
def train_one_epoch_iic(epoch, model, optim, lr_schdlr, ema, dltrain_f,
                        n_iters, logger, bt, mu):
    """Training function for IIC only."""
    model.train()
    loss_meter = AverageMeter()
    loss_iic_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_f = iter(dltrain_f)
    for it in range(n_iters):
        ims_s_weak, ims_s_strong, lbs_s_real = next(dl_f)  # SimCLR augmentations

        imgs = torch.cat([ims_s_weak, ims_s_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu)
        _, _, logits_iic = model(imgs)
        logits_iic = de_interleave(logits_iic, 2 * mu)

        # split the clustering-head outputs for the IIC loss
        logits_iic_w, logits_iic_s = torch.split(logits_iic, bt * mu)

        # loss_iic = IIC_loss(logits_s_w_h, logits_s_s_h)
        loss_iic, P = mi_loss(logits_iic_w, logits_iic_s)
        loss = loss_iic

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_iic_meter.update(loss_iic.item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start
            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)
            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. "
                "loss_iic: {:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_iic_meter.avg, lr_log, t))
            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_iic_meter.avg, model
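# NOTE (sketch, not part of the original listing): `mi_loss` is not shown here.
# One common formulation of the IIC mutual-information objective that matches
# the call `loss_iic, P = mi_loss(logits_iic_w, logits_iic_s)` is given below;
# whether the real `mi_loss` applies the softmax internally or expects
# probabilities is an assumption.
import torch


def mi_loss(z_w, z_s, eps=1e-8):
    p_w = torch.softmax(z_w, dim=1)
    p_s = torch.softmax(z_s, dim=1)
    P = p_w.t() @ p_s / p_w.size(0)        # joint distribution over cluster pairs
    P = ((P + P.t()) / 2).clamp(min=eps)   # symmetrize and avoid log(0)
    P = P / P.sum()
    Pi = P.sum(dim=1, keepdim=True)        # row marginal
    Pj = P.sum(dim=0, keepdim=True)        # column marginal
    # negative mutual information: minimizing this maximizes I(z_w; z_s)
    loss = -(P * (torch.log(P) - torch.log(Pi) - torch.log(Pj))).sum()
    return loss, P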
def train_one_epoch(
        epoch,
        model,
        criteria_x,
        criteria_u,
        criteria_z,
        optim,
        lr_schdlr,
        ema,
        dltrain_x,
        dltrain_u,
        lb_guessor,
        lambda_u,
        n_iters,
        logger,
):
    model.train()
    # loss_meter, loss_x_meter, loss_u_meter, loss_u_real_meter = [], [], [], []
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    loss_u_real_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()
    # the number of correctly-predicted and gradient-considered unlabeled samples
    n_correct_u_lbs_meter = AverageMeter()
    # the number of gradient-considered strong augmentations (confidence above threshold) of unlabeled samples
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
    for it in range(n_iters):
        ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
        ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)

        lbs_x = lbs_x.cuda()
        lbs_u_real = lbs_u_real.cuda()

        # --------------------------------------
        bt = ims_x_weak.size(0)
        mu = int(ims_u_weak.size(0) // bt)
        imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu + 1)
        logits, logit_z = model(imgs)
        # logits = model(imgs)
        logits_z = de_interleave(logit_z, 2 * mu + 1)
        logits = de_interleave(logits, 2 * mu + 1)

        logits_u_w_z, logits_u_s_z = torch.split(logits_z[bt:], bt * mu)
        logits_x = logits[:bt]
        logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)

        with torch.no_grad():
            probs = torch.softmax(logits_u_w, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(0.95).float()

        # on even epochs, train the projection space h with SimCLR first;
        # on odd epochs, train with the FixMatch losses
        if epoch % 2 == 0:
            loss_simCLR = criteria_z(logits_u_w_z, logits_u_s_z)
            with torch.no_grad():
                # logged only; no gradient through the FixMatch terms this epoch
                loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
                # loss_u = torch.zeros(1)
                loss_x = criteria_x(logits_x, lbs_x)
                # loss_x = torch.zeros(1)
            loss = loss_simCLR
        else:
            with torch.no_grad():
                # logged only; no gradient through the SimCLR term this epoch
                loss_simCLR = criteria_z(logits_u_w_z, logits_u_s_z)
                # loss_simCLR = torch.zeros(1)
            loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
            loss_x = criteria_x(logits_x, lbs_x)
            loss = loss_x + lambda_u * loss_u

        # per-sample reduction so the confidence mask applies element-wise
        loss_u_real = (F.cross_entropy(logits_u_s, lbs_u_real,
                                       reduction='none') * mask).mean()

        # --------------------------------------
        # mask, lbs_u_guess = lb_guessor(model, ims_u_weak.cuda())
        # n_x = ims_x_weak.size(0)
        # ims_x_u = torch.cat([ims_x_weak, ims_u_strong]).cuda()
        # logits_x_u = model(ims_x_u)
        # logits_x, logits_u = logits_x_u[:n_x], logits_x_u[n_x:]
        # loss_x = criteria_x(logits_x, lbs_x)
        # loss_u = (criteria_u(logits_u, lbs_u_guess) * mask).mean()
        # loss = loss_x + lambda_u * loss_u
        # loss_u_real = (F.cross_entropy(logits_u, lbs_u_real) * mask).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        loss_u_real_meter.update(loss_u_real.item())
        loss_simclr_meter.update(loss_simCLR.item())
        mask_meter.update(mask.mean().item())

        corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
        n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
        n_strong_aug_meter.update(mask.sum().item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start
            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)
            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
                "loss_simclr: {:.4f}. n_correct_u: {:.2f}/{:.2f}. "
                "Mask:{:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg,
                    loss_x_meter.avg, loss_u_real_meter.avg,
                    loss_simclr_meter.avg, n_correct_u_lbs_meter.avg,
                    n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))
            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg, \
        loss_u_real_meter.avg, loss_simclr_meter.avg, mask_meter.avg
def train_iter(
        self,
        model: Classifier,
        labeled_dataset: Dataset,
        unlabeled_dataset: Dataset) -> Generator[Stats, None, Any]:
    labeled_sampler = BatchSampler(RandomSampler(
        labeled_dataset, replacement=True,
        num_samples=self.num_iters * self.labeled_batch_size),
        batch_size=self.labeled_batch_size, drop_last=True)
    unlabeled_sampler = BatchSampler(RandomSampler(
        unlabeled_dataset, replacement=True,
        num_samples=self.num_iters * self.unlabeled_batch_size),
        batch_size=self.unlabeled_batch_size, drop_last=True)
    labeled_loader = DataLoader(
        labeled_dataset, batch_sampler=labeled_sampler,
        num_workers=self.num_workers, pin_memory=True)
    unlabeled_loader = DataLoader(
        unlabeled_dataset, batch_sampler=unlabeled_sampler,
        num_workers=self.num_workers, pin_memory=True)

    model.to(device=self.devices[0])
    param_avg = self.param_avg_ctor(model)

    # set up optimizer without weight decay on batch norm or bias parameters
    no_wd_filter = lambda m, k: isinstance(m, nn.BatchNorm2d) or k.endswith('bias')
    wd_filter = lambda m, k: not no_wd_filter(m, k)
    optim = self.model_optimizer_ctor([
        {'params': filter_parameters(model, wd_filter)},
        {'params': filter_parameters(model, no_wd_filter), 'weight_decay': 0.}
    ])
    scheduler = self.lr_scheduler_ctor(optim)
    scaler = torch.cuda.amp.GradScaler()

    if self.dist_alignment:
        labeled_dist = get_labeled_dist(labeled_dataset).to(self.devices[0])
        prev_labels = torch.full(
            [self.dist_alignment_batches, model.num_classes],
            1 / model.num_classes, device=self.devices[0])
        prev_labels_idx = 0

    # training loop
    for batch_idx, (b_l, b_u) in enumerate(zip(labeled_loader, unlabeled_loader)):
        # labeled examples
        xl, yl = b_l
        yl = yl.cuda(non_blocking=True)

        # augmented pairs of unlabeled examples
        (xw, xs), _ = b_u

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            x = torch.cat([xl, xs, xw]).cuda(non_blocking=True)
            num_blocks = x.shape[0] // xl.shape[0]
            x = interleave(x, num_blocks)
            out = torch.nn.parallel.data_parallel(
                model, x,
                module_kwargs={'autocast': self.mixed_precision},
                device_ids=self.devices)
            out = de_interleave(out, num_blocks)

            # get labels
            with torch.no_grad():
                probs = torch.softmax(out[-len(xw):], -1)
                if self.dist_alignment:
                    model_dist = prev_labels.mean(0)
                    prev_labels[prev_labels_idx] = probs.mean(0)
                    prev_labels_idx = (prev_labels_idx + 1) % self.dist_alignment_batches
                    probs *= (labeled_dist + self.dist_alignment_eps) / \
                        (model_dist + self.dist_alignment_eps)
                    probs /= probs.sum(-1, keepdim=True)
                yu = torch.argmax(probs, -1)
                mask = (torch.max(probs, -1)[0] >= self.threshold).to(dtype=torch.float32)

            loss_l = F.cross_entropy(out[:len(xl)], yl, reduction='mean')
            loss_u = (mask * F.cross_entropy(out[len(xl):-len(xw)], yu, reduction='none')).mean()
            loss = loss_l + self.unlabeled_weight * loss_u

        model.zero_grad()
        if self.mixed_precision:
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            optim.step()
        param_avg.step()
        scheduler.step()

        yield self.Stats(
            iter=batch_idx + 1,
            loss=loss.cpu().item(),
            loss_labeled=loss_l.cpu().item(),
            loss_unlabeled=loss_u.cpu().item(),
            model=model,
            avg_model=param_avg.avg_model,
            optimizer=optim,
            scheduler=scheduler,
            threshold_frac=mask.mean().cpu().item())
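# NOTE (sketch, not part of the original listing): `filter_parameters` is
# assumed to select model parameters by a predicate over (module, parameter
# name), as required by the weight-decay groups above. A minimal version
# consistent with that usage:
def filter_parameters(model, predicate):
    params = []
    for module in model.modules():
        # recurse=False so each parameter is attributed to its owning module
        for name, param in module.named_parameters(recurse=False):
            if predicate(module, name):
                params.append(param)
    return params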
def train_iter(
        self,
        model: Classifier,
        num_classes: int,
        labeled_dataset: Dataset,
        unlabeled_dataset: Dataset) -> Generator[Stats, None, Any]:
    labeled_sampler = BatchSampler(RandomSampler(
        labeled_dataset, replacement=True,
        num_samples=self.batches_per_epoch * self.labeled_batch_size),
        batch_size=self.labeled_batch_size, drop_last=True)
    labeled_loader = DataLoader(
        labeled_dataset, batch_sampler=labeled_sampler,
        num_workers=self.num_workers, pin_memory=True)
    unlabeled_sampler = BatchSampler(RandomSampler(
        unlabeled_dataset, replacement=True,
        num_samples=self.batches_per_epoch * self.unlabeled_batch_size),
        batch_size=self.unlabeled_batch_size, drop_last=True)
    unlabeled_loader = DataLoader(
        unlabeled_dataset, batch_sampler=unlabeled_sampler,
        num_workers=self.num_workers, pin_memory=True)

    # initialize model and optimizer
    model.to(device=self.devices[0])
    param_avg = self.param_avg_ctor(model)

    # set up optimizer without weight decay on batch norm or bias parameters
    no_wd_filter = lambda m, k: isinstance(m, nn.BatchNorm2d) or k.endswith('bias')
    wd_filter = lambda m, k: not no_wd_filter(m, k)
    optim = self.model_optimizer_ctor([
        {'params': filter_parameters(model, wd_filter)},
        {'params': filter_parameters(model, no_wd_filter), 'weight_decay': 0.}
    ])
    scheduler = self.lr_scheduler_ctor(optim)
    scaler = torch.cuda.amp.GradScaler()

    # initialize label assignment
    log_upper_bounds = get_log_upper_bounds(
        labeled_dataset,
        method=self.upper_bound_method,
        **self.upper_bound_kwargs)
    logger.info('upper bounds = {}'.format(torch.exp(log_upper_bounds)))
    label_assgn = SinkhornLabelAllocation(
        num_examples=len(unlabeled_dataset),
        log_upper_bounds=log_upper_bounds,
        allocation_param=0.,
        entropy_reg=self.entropy_reg,
        update_tol=self.update_tol,
        device=self.devices[0])

    # training loop
    for epoch in range(self.num_epochs):
        # (1) update model
        for batch_idx, (b_l, b_u) in enumerate(zip(labeled_loader, unlabeled_loader)):
            # labeled examples
            xl, yl = b_l
            yl = yl.cuda(non_blocking=True)

            # augmented pairs of unlabeled examples
            (xu1, xu2), idxs = b_u

            with torch.cuda.amp.autocast(enabled=self.mixed_precision):
                x = torch.cat([xl, xu1, xu2]).cuda(non_blocking=True)
                if len(self.devices) > 1:
                    num_blocks = x.shape[0] // xl.shape[0]
                    x = interleave(x, num_blocks)
                    out = torch.nn.parallel.data_parallel(
                        model, x,
                        module_kwargs={'autocast': self.mixed_precision},
                        device_ids=self.devices)
                    out = de_interleave(out, num_blocks)
                else:
                    out = model(x, autocast=self.mixed_precision)

                # compute labels
                logp_u = F.log_softmax(out[len(xl):], -1)
                nu = logp_u.shape[0] // 2
                qu = label_assgn.get_plan(log_p=logp_u[:nu].detach()).to(
                    dtype=torch.float32, device=out.device)
                qu = qu[:, :-1]

                # compute loss
                loss_l = F.cross_entropy(out[:len(xl)], yl, reduction='mean')
                loss_u = -(qu * logp_u[nu:]).sum(-1).mean()
                loss = loss_l + self.unlabeled_weight * loss_u

            # update plan
            rho = self.allocation_schedule(
                (epoch * self.batches_per_epoch + batch_idx + 1) /
                (self.num_epochs * self.batches_per_epoch))
            label_assgn.set_allocation_param(rho)
            label_assgn.update_loss_matrix(logp_u[:nu], idxs)
            assgn_err, assgn_iters = label_assgn.update()

            optim.zero_grad()
            if self.mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                optim.step()
            param_avg.step()
            scheduler.step()

            yield self.Stats(
                iter=epoch * self.batches_per_epoch + batch_idx + 1,
                loss=loss.cpu().item(),
                loss_labeled=loss_l.cpu().item(),
                loss_unlabeled=loss_u.cpu().item(),
                model=model,
                avg_model=param_avg.avg_model,
                allocation_param=rho,
                optimizer=optim,
                scheduler=scheduler,
                label_vars=qu,
                scaling_vars=label_assgn.v.data,
                assgn_err=assgn_err,
                assgn_iters=assgn_iters)
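# NOTE (hypothetical usage, not part of the original listing): both `train_iter`
# methods are generators that yield a `Stats` record per batch, so the caller
# drives logging and checkpointing. `SLATrainer` and `make_datasets` below are
# placeholder names, not identifiers from the original code.
# trainer = SLATrainer(...)                      # hypothetical constructor
# labeled_ds, unlabeled_ds = make_datasets(...)  # hypothetical helper
# for stats in trainer.train_iter(model, num_classes, labeled_ds, unlabeled_ds):
#     if stats.iter % 100 == 0:
#         print('iter {}: loss={:.4f}, rho={:.3f}'.format(
#             stats.iter, stats.loss, stats.allocation_param))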
def train_one_epoch(epoch, model, criteria_x, criteria_u, criteria_z, optim,
                    lr_schdlr, ema, dltrain_x, dltrain_u, dltrain_f,
                    lb_guessor, lambda_u, lambda_s, n_iters, logger, bt, mu):
    """Training function for FixMatch and SimCLR in the same epoch."""
    model.train()
    # loss_meter, loss_x_meter, loss_u_meter, loss_u_real_meter = [], [], [], []
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    loss_u_real_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()
    # the number of correctly-predicted and gradient-considered unlabeled samples
    n_correct_u_lbs_meter = AverageMeter()
    # the number of gradient-considered strong augmentations (confidence above threshold) of unlabeled samples
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_x, dl_u, dl_f = iter(dltrain_x), iter(dltrain_u), iter(dltrain_f)
    # dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
    for it in range(n_iters):
        ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
        ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)  # FixMatch augmentations
        ims_s_weak, ims_s_strong, lbs_s_real = next(dl_f)  # SimCLR augmentations

        lbs_x = lbs_x.cuda()
        lbs_u_real = lbs_u_real.cuda()

        # --------------------------------------
        imgs = torch.cat(
            [ims_x_weak, ims_u_weak, ims_u_strong, ims_s_weak, ims_s_strong],
            dim=0).cuda()
        # imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
        imgs = interleave(imgs, 4 * mu + 1)
        # imgs = interleave(imgs, 2 * mu + 1)
        logits, logit_z, _ = model(imgs)
        logits = de_interleave(logits, 4 * mu + 1)
        # logits = de_interleave(logits, 2 * mu + 1)
        # the projections must be de-interleaved as well before splitting
        logit_z = de_interleave(logit_z, 4 * mu + 1)

        # split the logits for the supervised FixMatch branch
        logits_x = logits[:bt]
        # split the logits for the unsupervised FixMatch branch
        logits_u_w, logits_u_s, _, _ = torch.split(logits[bt:], bt * mu)
        # split the projections for the unsupervised SimCLR branch
        _, _, logits_s_w, logits_s_s = torch.split(logit_z[bt:], bt * mu)

        # confidence mask from the weak FixMatch augmentation
        with torch.no_grad():
            probs = torch.softmax(logits_u_w, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(0.95).float()

        # compute the losses
        loss_s = criteria_z(logits_s_w, logits_s_s)
        loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
        loss_x = criteria_x(logits_x, lbs_x)
        loss = loss_x + loss_u * lambda_u + loss_s * lambda_s
        # per-sample reduction so the confidence mask applies element-wise
        loss_u_real = (F.cross_entropy(logits_u_s, lbs_u_real,
                                       reduction='none') * mask).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        loss_u_real_meter.update(loss_u_real.item())
        loss_simclr_meter.update(loss_s.item())
        mask_meter.update(mask.mean().item())

        corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
        n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
        n_strong_aug_meter.update(mask.sum().item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start
            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)
            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
                "n_correct_u: {:.2f}/{:.2f}. loss_s: {:.4f}. "
                "Mask:{:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg,
                    loss_x_meter.avg, loss_u_real_meter.avg,
                    n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg,
                    loss_simclr_meter.avg, mask_meter.avg, lr_log, t))
            # logger.info(
            #     "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
            #     "n_correct_u: {:.2f}/{:.2f}. "
            #     "Mask:{:.4f}. LR: {:.4f}. Time: {:.2f}".format(
            #         epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, loss_u_real_meter.avg,
            #         n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))
            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg, \
        loss_u_real_meter.avg, mask_meter.avg, loss_simclr_meter.avg