def train_one_epoch_simclr(epoch, model, criteria_z, optim, lr_schdlr, ema,
                           dltrain_f, lambda_s, n_iters, logger, bt, mu):
    """Run one epoch of SimCLR-only training.

    Every iteration pulls a ``(weak, strong, label)`` batch from ``dltrain_f``,
    runs both views through ``model`` in a single interleaved forward pass and
    minimises ``criteria_z`` between the two views' second model output (the
    projection). Labels are drawn but not used.

    Args:
        epoch: epoch index, only used in log lines.
        model: network whose forward returns three outputs; the second one is
            fed to ``criteria_z``.
        criteria_z: contrastive criterion called as
            ``criteria_z(z_weak, z_strong)``.
        optim: optimizer, stepped every iteration.
        lr_schdlr: learning-rate scheduler, stepped every iteration.
        ema: object with ``update_params()`` (called every iteration) and
            ``update_buffer()`` (called once at the end).
        dltrain_f: data loader yielding ``(weak, strong, label)`` batches.
        lambda_s: not used by this function.
        n_iters: number of iterations to run.
        logger: receives a progress line every 512 iterations.
        bt, mu: each view is split off as ``bt * mu`` rows and the interleave
            uses ``2 * mu`` groups.

    Returns:
        ``(avg total loss, avg SimCLR loss, model)``; the two averages are
        equal because the SimCLR loss is the only loss term.
    """
    model.train()

    total_meter = AverageMeter()
    simclr_meter = AverageMeter()
    n_groups = 2 * mu
    view_rows = bt * mu

    window_start = time.time()
    batches = iter(dltrain_f)
    for step in range(1, n_iters + 1):
        view_weak, view_strong, _ = next(batches)  # SimCLR-augmented views

        stacked = torch.cat([view_weak, view_strong], dim=0).cuda()
        _, proj, _ = model(interleave(stacked, n_groups))
        proj = de_interleave(proj, n_groups)
        # Rows follow the concatenation order: weak view first, strong second.
        z_weak, z_strong = torch.split(proj, view_rows)

        loss_s = criteria_z(z_weak, z_strong)

        optim.zero_grad()
        loss_s.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        value = loss_s.item()
        total_meter.update(value)
        simclr_meter.update(value)

        if step % 512 == 0:
            elapsed = time.time() - window_start
            lrs = [group['lr'] for group in optim.param_groups]
            mean_lr = sum(lrs) / len(lrs)
            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. "
                " loss_simclr: {:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, step, total_meter.avg, simclr_meter.avg,
                    mean_lr, elapsed))
            # Restart the timer so each log line reports its own window.
            window_start = time.time()

    ema.update_buffer()
    return total_meter.avg, simclr_meter.avg, model
def train_one_epoch_iic(epoch, model, optim, lr_schdlr, ema, dltrain_f,
                        n_iters, logger, bt, mu):
    """Run one epoch of IIC (mutual-information) training.

    Each iteration draws a ``(weak, strong, label)`` batch from ``dltrain_f``,
    pushes both views through ``model`` in one interleaved forward pass and
    minimises ``mi_loss`` between the third model output of the two views.
    Labels are drawn but not used.

    Args:
        epoch: epoch index, only used in log lines.
        model: network whose forward returns three outputs; the third one is
            fed to ``mi_loss``.
        optim: optimizer, stepped every iteration.
        lr_schdlr: learning-rate scheduler, stepped every iteration.
        ema: object with ``update_params()`` (every iteration) and
            ``update_buffer()`` (once at the end).
        dltrain_f: data loader yielding ``(weak, strong, label)`` batches.
        n_iters: number of iterations to run.
        logger: receives a progress line every 512 iterations.
        bt, mu: each view is split off as ``bt * mu`` rows and the interleave
            uses ``2 * mu`` groups.

    Returns:
        ``(avg total loss, avg IIC loss, model)``; the two averages are equal
        because the IIC loss is the only loss term.
    """
    model.train()
    total_meter = AverageMeter()
    iic_meter = AverageMeter()

    groups = 2 * mu
    per_view = bt * mu

    window_start = time.time()
    batches = iter(dltrain_f)
    for step in range(1, n_iters + 1):
        weak_view, strong_view, _ = next(batches)  # SimCLR-style augmentations

        both = interleave(torch.cat([weak_view, strong_view], dim=0).cuda(), groups)
        _, _, head_out = model(both)
        head_out = de_interleave(head_out, groups)
        # Split the IIC head output back into its weak-view and strong-view halves.
        out_weak, out_strong = torch.split(head_out, per_view)

        # mi_loss returns a pair; only the first element (the loss) is used.
        loss_iic, _ = mi_loss(out_weak, out_strong)

        optim.zero_grad()
        loss_iic.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        value = loss_iic.item()
        total_meter.update(value)
        iic_meter.update(value)

        if step % 512 == 0:
            elapsed = time.time() - window_start
            lrs = [group['lr'] for group in optim.param_groups]
            mean_lr = sum(lrs) / len(lrs)
            logger.info("epoch:{}, iter: {}. loss: {:.4f}. "
                        " loss_iic: {:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                            epoch, step, total_meter.avg, iic_meter.avg,
                            mean_lr, elapsed))
            window_start = time.time()

    ema.update_buffer()
    return total_meter.avg, iic_meter.avg, model
def train_one_epoch(
    epoch,
    model,
    criteria_x,
    criteria_u,
    criteria_z,
    optim,
    lr_schdlr,
    ema,
    dltrain_x,
    dltrain_u,
    lb_guessor,
    lambda_u,
    n_iters,
    logger,
    thr=0.95,
):
    """Train one epoch, alternating SimCLR and FixMatch by epoch parity.

    Even epochs optimise only the SimCLR loss on the projections of the weak
    and strong unlabeled views. The FixMatch losses are still computed, without
    gradients, so they can be logged. Odd epochs optimise the FixMatch loss
    ``loss_x + lambda_u * loss_u`` and only log the SimCLR term.

    Args:
        epoch: epoch index; its parity selects the optimised objective.
        model: network returning ``(logits, projection)`` for a batch.
        criteria_x: supervised criterion ``criteria_x(logits, labels)``.
        criteria_u: unlabeled criterion. Its output is multiplied element-wise
            by the confidence mask, so it should return per-sample losses
            (e.g. ``reduction='none'``) for the mask to act per sample.
        criteria_z: contrastive criterion ``criteria_z(z_weak, z_strong)``.
        optim, lr_schdlr: optimizer and scheduler, both stepped every iteration.
        ema: object exposing ``update_params()`` and ``update_buffer()``.
        dltrain_x: loader yielding labeled ``(weak, strong, label)`` batches.
        dltrain_u: loader yielding unlabeled ``(weak, strong, label)`` batches;
            the labels are used only for diagnostics.
        lb_guessor: not used by this function.
        lambda_u: weight of the unlabeled loss on odd epochs.
        n_iters: number of iterations.
        logger: receives a progress line every 512 iterations.
        thr: pseudo-label confidence threshold (defaults to 0.95, the value
            previously hard-coded).

    Returns:
        Epoch averages ``(loss, loss_x, loss_u, loss_u_real, loss_simclr,
        mask)``.
    """
    model.train()
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    loss_u_real_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()
    # number of masked-in unlabeled samples whose pseudo-label is correct
    n_correct_u_lbs_meter = AverageMeter()
    # number of unlabeled samples whose weak-view confidence passes the threshold
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()

    epoch_start = time.time()  # start of the current logging window
    dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
    for it in range(n_iters):
        ims_x_weak, _, lbs_x = next(dl_x)
        ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)

        lbs_x = lbs_x.cuda()
        lbs_u_real = lbs_u_real.cuda()

        bt = ims_x_weak.size(0)
        mu = int(ims_u_weak.size(0) // bt)
        # One forward pass over [labeled weak | unlabeled weak | unlabeled strong],
        # interleaved into 2 * mu + 1 groups and restored afterwards.
        imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu + 1)
        # NOTE(review): the other training functions in this file unpack three
        # outputs from model(...); confirm which model variant this one expects.
        logits, logit_z = model(imgs)
        logits_z = de_interleave(logit_z, 2 * mu + 1)
        logits = de_interleave(logits, 2 * mu + 1)

        logits_u_w_z, logits_u_s_z = torch.split(logits_z[bt:], bt * mu)

        logits_x = logits[:bt]
        logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)

        # Hard pseudo-labels from the weak view, kept only above the threshold.
        with torch.no_grad():
            probs = torch.softmax(logits_u_w, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(thr).float()

        if epoch % 2 == 0:
            # SimCLR phase: only the contrastive loss receives gradients.
            loss_simCLR = criteria_z(logits_u_w_z, logits_u_s_z)
            with torch.no_grad():
                loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
                loss_x = criteria_x(logits_x, lbs_x)
            loss = loss_simCLR
        else:
            # FixMatch phase: the contrastive loss is only logged.
            with torch.no_grad():
                loss_simCLR = criteria_z(logits_u_w_z, logits_u_s_z)
            loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
            loss_x = criteria_x(logits_x, lbs_x)
            loss = loss_x + lambda_u * loss_u

        # Diagnostic only: masked CE of the strong-view logits against the true
        # labels. reduction='none' keeps one value per sample so the mask can
        # select samples; the default 'mean' averaged over the whole batch first,
        # which made this mean(CE) * mask_fraction instead.
        with torch.no_grad():
            loss_u_real = (F.cross_entropy(
                logits_u_s, lbs_u_real, reduction='none') * mask).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        loss_u_real_meter.update(loss_u_real.item())
        loss_simclr_meter.update(loss_simCLR.item())
        mask_meter.update(mask.mean().item())

        corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
        n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
        n_strong_aug_meter.update(mask.sum().item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
                " loss_simclr: {:.4f} n_correct_u: {:.2f}/{:.2f}. "
                "Mask:{:.4f} . LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg,
                    loss_x_meter.avg, loss_u_real_meter.avg,
                    loss_simclr_meter.avg, n_correct_u_lbs_meter.avg,
                    n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg,\
           loss_u_real_meter.avg, loss_simclr_meter.avg, mask_meter.avg
Beispiel #4
0
    def train_iter(
            self,
            model: Classifier,
            labeled_dataset: Dataset,
            unlabeled_dataset: Dataset) -> Generator[Stats, None, Any]:
        """Train with confidence-thresholded pseudo-labels, yielding stats per step.

        ``self.num_iters`` steps are run. Each step pairs a labeled batch
        ``(xl, yl)`` with an unlabeled batch ``((xw, xs), _)``. Predictions on
        ``xw`` provide hard pseudo-labels; when ``self.dist_alignment`` is set
        they are first rescaled toward the labeled class distribution. Samples
        whose top probability reaches ``self.threshold`` contribute a
        cross-entropy term on the ``xs`` predictions, weighted by
        ``self.unlabeled_weight``.

        Yields:
            ``self.Stats`` after every optimizer step, including the fraction
            of unlabeled samples that passed the threshold.
        """
        device = self.devices[0]

        def make_loader(dataset, batch_size):
            # Sampling with replacement yields exactly num_iters full batches.
            sampler = BatchSampler(
                RandomSampler(dataset, replacement=True,
                              num_samples=self.num_iters * batch_size),
                batch_size=batch_size, drop_last=True)
            return DataLoader(dataset, batch_sampler=sampler,
                              num_workers=self.num_workers, pin_memory=True)

        labeled_loader = make_loader(labeled_dataset, self.labeled_batch_size)
        unlabeled_loader = make_loader(unlabeled_dataset, self.unlabeled_batch_size)

        model.to(device=device)
        param_avg = self.param_avg_ctor(model)

        # BatchNorm parameters and biases are excluded from weight decay.
        def skip_decay(module, name):
            return isinstance(module, nn.BatchNorm2d) or name.endswith('bias')

        optim = self.model_optimizer_ctor([
            {'params': filter_parameters(model, lambda m, k: not skip_decay(m, k))},
            {'params': filter_parameters(model, skip_decay), 'weight_decay': 0.},
        ])

        scheduler = self.lr_scheduler_ctor(optim)
        scaler = torch.cuda.amp.GradScaler()

        if self.dist_alignment:
            labeled_dist = get_labeled_dist(labeled_dataset).to(device)
            n_cls = model.num_classes
            # Ring buffer of recent mean predictions, initialised to uniform.
            prev_labels = torch.full(
                [self.dist_alignment_batches, n_cls], 1 / n_cls, device=device)
            prev_labels_idx = 0

        for step, ((xl, yl), ((xw, xs), _)) in enumerate(
                zip(labeled_loader, unlabeled_loader), start=1):
            yl = yl.cuda(non_blocking=True)
            n_lab, n_weak = len(xl), len(xw)

            with torch.cuda.amp.autocast(enabled=self.mixed_precision):
                # Order is [labeled | strong | weak]; the slices below rely on it.
                batch = torch.cat([xl, xs, xw]).cuda(non_blocking=True)
                blocks = batch.shape[0] // xl.shape[0]
                out = torch.nn.parallel.data_parallel(
                    model, interleave(batch, blocks),
                    module_kwargs={'autocast': self.mixed_precision},
                    device_ids=self.devices)
                out = de_interleave(out, blocks)

                with torch.no_grad():
                    probs = torch.softmax(out[-n_weak:], -1)
                    if self.dist_alignment:
                        # Average the buffered batches *before* adding this one.
                        model_dist = prev_labels.mean(0)
                        prev_labels[prev_labels_idx] = probs.mean(0)
                        prev_labels_idx = (prev_labels_idx + 1) % self.dist_alignment_batches
                        eps = self.dist_alignment_eps
                        probs *= (labeled_dist + eps) / (model_dist + eps)
                        probs /= probs.sum(-1, keepdim=True)
                    yu = probs.argmax(-1)
                    mask = (probs.max(-1)[0] >= self.threshold).to(dtype=torch.float32)

                loss_l = F.cross_entropy(out[:n_lab], yl, reduction='mean')
                per_sample_u = F.cross_entropy(out[n_lab:-n_weak], yu, reduction='none')
                loss_u = (mask * per_sample_u).mean()
                loss = loss_l + self.unlabeled_weight * loss_u

            model.zero_grad()
            if not self.mixed_precision:
                loss.backward()
                optim.step()
            else:
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            param_avg.step()
            scheduler.step()

            yield self.Stats(
                iter=step,
                loss=loss.cpu().item(),
                loss_labeled=loss_l.cpu().item(),
                loss_unlabeled=loss_u.cpu().item(),
                model=model,
                avg_model=param_avg.avg_model,
                optimizer=optim,
                scheduler=scheduler,
                threshold_frac=mask.mean().cpu().item())
Beispiel #5
0
    def train_iter(
            self,
            model: Classifier,
            num_classes: int,
            labeled_dataset: Dataset,
            unlabeled_dataset: Dataset) -> Generator[Stats, None, Any]:
        """Train with Sinkhorn-style label allocation, yielding stats each step.

        Runs ``self.num_epochs`` epochs of ``self.batches_per_epoch`` batches.
        Each step combines a supervised cross-entropy on a labeled batch with a
        soft-target loss on an unlabeled pair of augmented views. The targets
        for the second view come from ``label_assgn.get_plan`` applied to the
        model's log-probabilities on the first view. After every optimizer step
        a ``self.Stats`` record is yielded.

        Args:
            model: classifier; called as ``model(x, autocast=...)`` or through
                ``data_parallel`` with ``module_kwargs={'autocast': ...}``.
            num_classes: not read by this method.
            labeled_dataset: batches unpack as ``(xl, yl)``.
            unlabeled_dataset: batches unpack as ``((xu1, xu2), idxs)``;
                ``idxs`` is forwarded to ``label_assgn.update_loss_matrix``
                (presumably dataset indices of the examples -- confirm).

        Yields:
            ``self.Stats`` with the losses, model/optimizer/scheduler handles,
            the current allocation parameter ``rho``, the soft targets ``qu``
            and the allocator's error / iteration count for this update.
        """

        # Sample with replacement so every epoch draws exactly
        # batches_per_epoch full batches from each dataset.
        labeled_sampler = BatchSampler(RandomSampler(
            labeled_dataset, replacement=True, num_samples=self.batches_per_epoch * self.labeled_batch_size),
            batch_size=self.labeled_batch_size, drop_last=True)
        labeled_loader = DataLoader(
            labeled_dataset, batch_sampler=labeled_sampler, num_workers=self.num_workers, pin_memory=True)
        unlabeled_sampler = BatchSampler(RandomSampler(
            unlabeled_dataset, replacement=True,
            num_samples=self.batches_per_epoch * self.unlabeled_batch_size),
            batch_size=self.unlabeled_batch_size, drop_last=True)
        unlabeled_loader = DataLoader(
            unlabeled_dataset, batch_sampler=unlabeled_sampler, num_workers=self.num_workers, pin_memory=True)

        # initialize model and optimizer
        model.to(device=self.devices[0])
        param_avg = self.param_avg_ctor(model)

        # set up optimizer without weight decay on batch norm or bias parameters
        no_wd_filter = lambda m, k: isinstance(m, nn.BatchNorm2d) or k.endswith('bias')
        wd_filter = lambda m, k: not no_wd_filter(m, k)
        optim = self.model_optimizer_ctor([
            {'params': filter_parameters(model, wd_filter)},
            {'params': filter_parameters(model, no_wd_filter), 'weight_decay': 0.}
        ])

        scheduler = self.lr_scheduler_ctor(optim)
        # Only used on the mixed-precision path below.
        scaler = torch.cuda.amp.GradScaler()

        # initialize label assignment
        # Per-class upper bounds estimated from the labeled set; they are logged
        # in probability space via exp().
        log_upper_bounds = get_log_upper_bounds(
            labeled_dataset, method=self.upper_bound_method, **self.upper_bound_kwargs)
        logger.info('upper bounds = {}'.format(torch.exp(log_upper_bounds)))
        # One allocator state for the whole unlabeled set; allocation_param
        # starts at 0 and is set each step from self.allocation_schedule.
        label_assgn = SinkhornLabelAllocation(
            num_examples=len(unlabeled_dataset),
            log_upper_bounds=log_upper_bounds,
            allocation_param=0.,
            entropy_reg=self.entropy_reg,
            update_tol=self.update_tol,
            device=self.devices[0])

        # training loop
        for epoch in range(self.num_epochs):
            # (1) update model
            for batch_idx, (b_l, b_u) in enumerate(zip(labeled_loader, unlabeled_loader)):
                # labeled examples
                xl, yl = b_l
                yl = yl.cuda(non_blocking=True)

                # augmented pairs of unlabeled examples
                (xu1, xu2), idxs = b_u

                with torch.cuda.amp.autocast(enabled=self.mixed_precision):
                    # Order is [labeled | view 1 | view 2]; slicing below relies on it.
                    x = torch.cat([xl, xu1, xu2]).cuda(non_blocking=True)
                    if len(self.devices) > 1:
                        # Interleave only on the multi-device data_parallel path
                        # (presumably so each device slice mixes labeled and
                        # unlabeled samples -- confirm).
                        num_blocks = x.shape[0] // xl.shape[0]
                        x = interleave(x, num_blocks)
                        out = torch.nn.parallel.data_parallel(
                            model, x, module_kwargs={'autocast': self.mixed_precision}, device_ids=self.devices)
                        out = de_interleave(out, num_blocks)
                    else:
                        out = model(x, autocast=self.mixed_precision)

                    # compute labels
                    # Rows [:nu] are view-1 predictions, rows [nu:] view-2
                    # (assumes both views have the same batch size).
                    logp_u = F.log_softmax(out[len(xl):], -1)
                    nu = logp_u.shape[0] // 2
                    # Targets come from view 1 and are detached, so no gradient
                    # flows through them.
                    qu = label_assgn.get_plan(log_p=logp_u[:nu].detach()).to(dtype=torch.float32, device=out.device)
                    # Drop the plan's last column. NOTE(review): presumably an
                    # extra non-class / slack column -- confirm in
                    # SinkhornLabelAllocation.
                    qu = qu[:, :-1]

                    # compute loss
                    loss_l = F.cross_entropy(out[:len(xl)], yl, reduction='mean')
                    # Soft cross-entropy of view-2 predictions against view-1 targets.
                    loss_u = -(qu * logp_u[nu:]).sum(-1).mean()
                    loss = loss_l + self.unlabeled_weight * loss_u

                    # update plan
                    # Schedule input is the fraction of training completed, in (0, 1].
                    rho = self.allocation_schedule(
                        (epoch * self.batches_per_epoch + batch_idx + 1) /
                        (self.num_epochs * self.batches_per_epoch))
                    label_assgn.set_allocation_param(rho)
                    # NOTE(review): unlike get_plan above, this receives a
                    # non-detached tensor; confirm the allocator detaches/copies it.
                    label_assgn.update_loss_matrix(logp_u[:nu], idxs)
                    assgn_err, assgn_iters = label_assgn.update()

                optim.zero_grad()
                if self.mixed_precision:
                    scaler.scale(loss).backward()
                    scaler.step(optim)
                    scaler.update()
                else:
                    loss.backward()
                    optim.step()
                param_avg.step()
                scheduler.step()

                yield self.Stats(
                    iter=epoch * self.batches_per_epoch + batch_idx + 1,
                    loss=loss.cpu().item(),
                    loss_labeled=loss_l.cpu().item(),
                    loss_unlabeled=loss_u.cpu().item(),
                    model=model,
                    avg_model=param_avg.avg_model,
                    allocation_param=rho,
                    optimizer=optim,
                    scheduler=scheduler,
                    label_vars=qu,
                    scaling_vars=label_assgn.v.data,
                    assgn_err=assgn_err,
                    assgn_iters=assgn_iters)
def train_one_epoch(epoch, model, criteria_x, criteria_u, criteria_z, optim,
                    lr_schdlr, ema, dltrain_x, dltrain_u, dltrain_f,
                    lb_guessor, lambda_u, lambda_s, n_iters, logger, bt, mu,
                    thr=0.95):
    """Train FixMatch and SimCLR jointly in every iteration of one epoch.

    Each iteration draws a labeled batch, a FixMatch-augmented unlabeled batch
    and a SimCLR-augmented batch. All five image groups go through ``model`` in
    one interleaved forward pass, and the optimised loss is
    ``loss_x + lambda_u * loss_u + lambda_s * loss_s``.

    Args:
        epoch: epoch index, only used in log lines.
        model: network returning three outputs. The first (class logits) feeds
            the FixMatch losses; the second (projection) feeds ``criteria_z``.
        criteria_x: supervised criterion ``criteria_x(logits, labels)``.
        criteria_u: unlabeled criterion. It is multiplied element-wise by the
            confidence mask, so it should return per-sample losses.
        criteria_z: contrastive criterion ``criteria_z(z_weak, z_strong)``.
        optim, lr_schdlr: optimizer and scheduler, stepped every iteration.
        ema: object exposing ``update_params()`` and ``update_buffer()``.
        dltrain_x, dltrain_u, dltrain_f: loaders yielding
            ``(weak, strong, label)`` batches (labeled, FixMatch-unlabeled and
            SimCLR, respectively).
        lb_guessor: not used by this function.
        lambda_u, lambda_s: weights of the unlabeled and SimCLR losses.
        n_iters: number of iterations.
        logger: receives a progress line every 512 iterations.
        bt: number of labeled rows per batch; every other image group is
            assumed to have ``bt * mu`` rows.
        mu: unlabeled-to-labeled batch ratio.
        thr: pseudo-label confidence threshold (defaults to 0.95, the value
            previously hard-coded).

    Returns:
        Epoch averages ``(loss, loss_x, loss_u, loss_u_real, mask, loss_s)``.
    """
    model.train()
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    loss_u_real_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()
    # number of masked-in unlabeled samples whose pseudo-label is correct
    n_correct_u_lbs_meter = AverageMeter()
    # number of unlabeled samples whose weak-view confidence passes the threshold
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()

    n_groups = 4 * mu + 1  # 1 labeled group + 4 groups of bt * mu rows
    epoch_start = time.time()  # start of the current logging window
    dl_x, dl_u, dl_f = iter(dltrain_x), iter(dltrain_u), iter(dltrain_f)
    for it in range(n_iters):
        ims_x_weak, _, lbs_x = next(dl_x)
        ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)  # FixMatch augmentations
        ims_s_weak, ims_s_strong, _ = next(dl_f)  # SimCLR augmentations

        lbs_x = lbs_x.cuda()
        lbs_u_real = lbs_u_real.cuda()

        imgs = torch.cat(
            [ims_x_weak, ims_u_weak, ims_u_strong, ims_s_weak, ims_s_strong],
            dim=0).cuda()
        imgs = interleave(imgs, n_groups)
        logits, logit_z, _ = model(imgs)
        # Both heads must be restored to concatenation order before slicing.
        # Previously only `logits` was de-interleaved, so the SimCLR split
        # below paired projections of unrelated images.
        logits = de_interleave(logits, n_groups)
        logit_z = de_interleave(logit_z, n_groups)

        # Supervised FixMatch logits.
        logits_x = logits[:bt]
        # Unsupervised FixMatch logits (weak / strong views).
        logits_u_w, logits_u_s, _, _ = torch.split(logits[bt:], bt * mu)
        # SimCLR projections (weak / strong SimCLR views).
        _, _, logits_s_w, logits_s_s = torch.split(logit_z[bt:], bt * mu)

        # Confidence mask from the weak FixMatch view.
        with torch.no_grad():
            probs = torch.softmax(logits_u_w, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(thr).float()

        loss_s = criteria_z(logits_s_w, logits_s_s)
        loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
        loss_x = criteria_x(logits_x, lbs_x)

        loss = loss_x + loss_u * lambda_u + loss_s * lambda_s

        # Diagnostic only: masked CE of the strong-view logits against the true
        # labels. reduction='none' keeps one value per sample so the mask can
        # select samples; the default 'mean' averaged over the whole batch first.
        with torch.no_grad():
            loss_u_real = (F.cross_entropy(
                logits_u_s, lbs_u_real, reduction='none') * mask).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        loss_u_real_meter.update(loss_u_real.item())
        loss_simclr_meter.update(loss_s.item())
        mask_meter.update(mask.mean().item())

        corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
        n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
        n_strong_aug_meter.update(mask.sum().item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
                "n_correct_u: {:.2f}/{:.2f}. loss_s: {:.4f}. "
                "Mask:{:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg,
                    loss_x_meter.avg, loss_u_real_meter.avg,
                    n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg,
                    loss_simclr_meter.avg, mask_meter.avg, lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg,\
           loss_u_real_meter.avg, mask_meter.avg, loss_simclr_meter.avg