Example #1
def train(trainX, trainU, model, ema_model, optimizer, epoch):
    xe_loss_avg = tf.keras.metrics.Mean()
    l2u_loss_avg = tf.keras.metrics.Mean()
    total_loss_avg = tf.keras.metrics.Mean()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    shuffle_and_batch = lambda dataset: dataset.shuffle(buffer_size=int(1e6)).batch(batch_size=64, drop_remainder=True)

    iteratorX = iter(shuffle_and_batch(trainX))
    iteratorU = iter(shuffle_and_batch(trainU))

    progress_bar = tqdm(range(1024), unit='batch')
    for batch_num in progress_bar:
        lambda_u = 100 * linear_rampup(epoch + batch_num/1024, 16)
        try:
            batchX = next(iteratorX)
        except StopIteration:
            iteratorX = iter(shuffle_and_batch(trainX))
            batchX = next(iteratorX)
        try:
            batchU = next(iteratorU)
        except StopIteration:
            iteratorU = iter(shuffle_and_batch(trainU))
            batchU = next(iteratorU)

        #args['beta'].assign(np.random.beta(args['alpha'], args['alpha']))
        beta = np.random.beta(0.75,0.75)
        with tf.GradientTape() as tape:
            # run mixmatch
            XU, XUy = mixmatch(model, batchX['image'], batchX['label'], batchU['image'], 0.5, 2, beta)
            logits = [model(XU[0])]
            for batch in XU[1:]:
                logits.append(model(batch))
            logits = interleave(logits, 64)
            logits_x = logits[0]
            logits_u = tf.concat(logits[1:], axis=0)

            # compute loss
            xe_loss, l2u_loss = semi_loss(XUy[:64], logits_x, XUy[64:], logits_u)
            total_loss = xe_loss + lambda_u * l2u_loss

        # compute gradients and run optimizer step
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        ema(model, ema_model, 0.999)
        weight_decay(model=model, decay_rate=0.02 * 0.01)

        xe_loss_avg(xe_loss)
        l2u_loss_avg(l2u_loss)
        total_loss_avg(total_loss)
        accuracy(tf.argmax(batchX['label'], axis=1, output_type=tf.int32), model(tf.cast(batchX['image'], dtype=tf.float32), training=False))

        progress_bar.set_postfix({
            'XE Loss': f'{xe_loss_avg.result():.4f}',
            'L2U Loss': f'{l2u_loss_avg.result():.4f}',
            'WeightU': f'{lambda_u:.3f}',
            'Total Loss': f'{total_loss_avg.result():.4f}',
            'Accuracy': f'{accuracy.result():.3%}'
        })
    return xe_loss_avg, l2u_loss_avg, total_loss_avg, accuracy
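# A minimal sketch, assuming the list-based interleave helper the MixMatch loop
# above relies on (the repo's own implementation may differ): slices are swapped
# between the labeled batch and the guessed-label batches so each forward pass
# sees a mix of labeled and unlabeled examples and BatchNorm statistics stay unbiased.
import tensorflow as tf

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [tf.concat(v, axis=0) for v in xy]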
Example #2
def train_one_epoch_simclr(epoch, model, criteria_z, optim, lr_schdlr, ema,
                           dltrain_f, lambda_s, n_iters, logger, bt, mu):
    """
    TRAINING FUNCTION FOR SIMCLR ONLY
    """
    model.train()

    loss_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_f = iter(dltrain_f)
    for it in range(n_iters):
        ims_s_weak, ims_s_strong, lbs_s_real = next(
            dl_f)  # with SimCLR transformations

        imgs = torch.cat([ims_s_weak, ims_s_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu)
        logits, logit_z, _ = model(imgs)
        logits_z = de_interleave(logit_z, 2 * mu)

        # SPLIT THE REPRESENTATIONS FOR SIMCLR
        logits_s_w_z, logits_s_s_z = torch.split(logits_z, bt * mu)

        loss_s = criteria_z(logits_s_w_z, logits_s_s_z)

        loss = loss_s

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_simclr_meter.update(loss_s.item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. "
                " loss_simclr: {:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_simclr_meter.avg,
                    lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_simclr_meter.avg, model
Example #3
def train_one_epoch_iic(epoch, model, optim, lr_schdlr, ema, dltrain_f,
                        n_iters, logger, bt, mu):
    model.train()
    loss_meter = AverageMeter()
    loss_iic_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_f = iter(dltrain_f)
    for it in range(n_iters):
        ims_s_weak, ims_s_strong, lbs_s_real = next(
            dl_f)  # with SimCLR transformations

        imgs = torch.cat([ims_s_weak, ims_s_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu)
        _, _, logits_iic = model(imgs)
        logits_iic = de_interleave(logits_iic, 2 * mu)

        # SPLIT THE LAST REPRESENTATIONS FOR SIMCLR
        logits_iic_w, logits_iic_s = torch.split(logits_iic, bt * mu)

        # loss_iic = IIC_loss(logits_s_w_h, logits_s_s_h)
        loss_iic, P = mi_loss(logits_iic_w, logits_iic_s)

        loss = loss_iic

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_iic_meter.update(loss_iic.item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info("epoch:{}, iter: {}. loss: {:.4f}. "
                        " loss_iic: {:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                            epoch, it + 1, loss_meter.avg, loss_iic_meter.avg,
                            lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_iic_meter.avg, model
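# A minimal sketch, assuming the reshape-based interleave/de_interleave used by
# the SimCLR/FixMatch-style PyTorch loops above (the actual helpers in each repo
# may differ): batches are woven together along dim 0 before the forward pass so
# BatchNorm sees all sample groups together, then unwoven afterwards.
import torch

def interleave(x, size):
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])

def de_interleave(x, size):
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])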
Example #4
    def sample(self):
        div_factor = (2**(self.scale_count // 2))
        num_mels = self.n_mels // div_factor
        timesteps = self.timesteps // div_factor
        batchsize = 1

        if self.scale_count % 2 != 0:
            num_mels //= 2

        axis = False
        output = None
        for i in range(len(self.n_layers)):
            x = torch.zeros(batchsize, timesteps, num_mels).to(self.device)
            melnet = self.melnets[i].module

            if output is not None:
                f_ext = self.f_exts[i - 1].module
                cond = f_ext(output)
            else:
                cond = None

            # Autoregression
            t = datetime.now()

            for j in range(timesteps):
                for k in range(num_mels):
                    torch.cuda.synchronize()
                    mu, sigma, pi = (item[:, j, k].float()
                                     for item in melnet(x.clone(), cond))
                    idx = pi.exp().multinomial(1)
                    x[:, j,
                      k] = torch.normal(mu, sigma).gather(-1, idx).squeeze(-1)
            print(f"Sampling Time: {datetime.now() - t}")

            if i == 0:
                output = x
            else:
                output = interleave(output, x, axis)
                _, timesteps, num_mels = output.size()
                axis = not axis

        return output
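# A hypothetical sketch of the two-tensor interleave assumed by the MelNet-style
# sampler above: frames of `a` and `b` are alternated along the time axis
# (axis=False) or the mel axis (axis=True), doubling that dimension each pass.
import torch

def interleave(a, b, axis):
    dim = 2 if axis else 1                      # mel axis vs. time axis
    stacked = torch.stack([a, b], dim=dim + 1)  # e.g. (B, T, 2, M) for the time axis
    shape = list(a.shape)
    shape[dim] *= 2
    return stacked.reshape(shape)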
Example #5
    def sample(self):
        cond = None
        melnet = self.melnets[0].to(self.device)
        timesteps = self.config.timesteps * 2 // self.scale_count
        num_mels = self.config.n_mel * 2 // self.scale_count
        axis = False
        for i in range(len(self.n_layers)):
            x = torch.zeros(1, timesteps, num_mels).to(self.device)
            melnet = self.melnets[i].to(self.device)

            # Autoregression
            for _ in range(timesteps * num_mels):
                mu, sigma, pi = melnet(x, cond)
                x = sample(mu, sigma, pi)

            if i == 0:
                cond = x
            elif i != len(self.n_layers) - 1:
                cond = interleave(cond, x, axis)
                _, timesteps, num_mels = cond.size()
                axis = not axis
        return x
Example #6
def train_one_epoch(
    epoch,
    model,
    criteria_x,
    criteria_u,
    criteria_z,
    optim,
    lr_schdlr,
    ema,
    dltrain_x,
    dltrain_u,
    lb_guessor,
    lambda_u,
    n_iters,
    logger,
):
    model.train()
    # loss_meter, loss_x_meter, loss_u_meter, loss_u_real_meter = [], [], [], []
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    loss_u_real_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()
    # number of unlabeled samples that pass the confidence threshold and are correctly pseudo-labeled
    n_correct_u_lbs_meter = AverageMeter()
    # number of strongly-augmented unlabeled samples whose confidence exceeds the threshold (i.e. contribute gradient)
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
    for it in range(n_iters):
        ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
        ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)

        lbs_x = lbs_x.cuda()
        lbs_u_real = lbs_u_real.cuda()

        # --------------------------------------

        bt = ims_x_weak.size(0)
        mu = int(ims_u_weak.size(0) // bt)
        imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
        imgs = interleave(imgs, 2 * mu + 1)
        logits, logit_z = model(imgs)
        # logits = model(imgs)
        logits_z = de_interleave(logit_z, 2 * mu + 1)
        logits = de_interleave(logits, 2 * mu + 1)

        logits_u_w_z, logits_u_s_z = torch.split(logits_z[bt:], bt * mu)

        logits_x = logits[:bt]
        logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)

        with torch.no_grad():
            probs = torch.softmax(logits_u_w, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(0.95).float()

        # first train the h-space of the split images with SimCLR
        if epoch % 2 == 0:
            loss_simCLR = (criteria_z(logits_u_w_z, logits_u_s_z))

            with torch.no_grad():
                loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
                # loss_u = torch.zeros(1)
                loss_x = criteria_x(logits_x, lbs_x)
                # loss_x = torch.zeros(1)

            loss = loss_simCLR
        else:
            with torch.no_grad():
                loss_simCLR = (criteria_z(logits_u_w_z, logits_u_s_z))
                # loss_simCLR = torch.zeros(1)

            loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
            loss_x = criteria_x(logits_x, lbs_x)
            loss = loss_x + lambda_u * loss_u
        loss_u_real = (F.cross_entropy(logits_u_s, lbs_u_real) * mask).mean()

        # --------------------------------------

        # mask, lbs_u_guess = lb_guessor(model, ims_u_weak.cuda())
        # n_x = ims_x_weak.size(0)
        # ims_x_u = torch.cat([ims_x_weak, ims_u_strong]).cuda()
        # logits_x_u = model(ims_x_u)
        # logits_x, logits_u = logits_x_u[:n_x], logits_x_u[n_x:]
        # loss_x = criteria_x(logits_x, lbs_x)
        # loss_u = (criteria_u(logits_u, lbs_u_guess) * mask).mean()
        # loss = loss_x + lambda_u * loss_u
        # loss_u_real = (F.cross_entropy(logits_u, lbs_u_real) * mask).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        loss_u_real_meter.update(loss_u_real.item())
        loss_simclr_meter.update(loss_simCLR.item())
        mask_meter.update(mask.mean().item())

        corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
        n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
        n_strong_aug_meter.update(mask.sum().item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
                " loss_simclr: {:.4f} n_correct_u: {:.2f}/{:.2f}. "
                "Mask:{:.4f} . LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg,
                    loss_x_meter.avg, loss_u_real_meter.avg,
                    loss_simclr_meter.avg, n_correct_u_lbs_meter.avg,
                    n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg,\
           loss_u_real_meter.avg, loss_simclr_meter.avg, mask_meter.avg
Example #7
 def format(self, names):
     return "".join(interleave(self.text, names))
Example #8
def test_interleave():
    x = torch.rand(960, 3, 32, 32)
    y = interleave(x, 15)
    true_y = np_interleave(x.numpy(), 15)
    assert (y.numpy() == true_y).all()
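# A NumPy reference the test above could be comparing against -- an assumption
# mirroring the reshape-based interleave, not necessarily the repo's np_interleave:
import numpy as np

def np_interleave(x, size):
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).swapaxes(0, 1).reshape([-1] + s[1:])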
Example #9
def train(args):
    ## set pre-process
    dset_loaders = data_load(args)

    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    args.max_iter = args.max_epoch * max_len

    ## set base network
    if args.net == 'resnet34':
        netG = utils.ResBase34().cuda()
    elif args.net == 'vgg16':
        netG = utils.VGG16Base().cuda()

    netF = utils.ResClassifier(class_num=args.class_num,
                               feature_dim=netG.in_features,
                               bottleneck_dim=args.bottleneck_dim).cuda()

    if len(args.gpu_id.split(',')) > 1:
        netG = nn.DataParallel(netG)

    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])
    ltarget_loader_iter = iter(dset_loaders["ltarget"])

    if args.pl.startswith('atdoc_na'):
        mem_fea = torch.rand(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset), args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)
        mem_cls = torch.ones(
            len(dset_loaders["target"].dataset) +
            len(dset_loaders["ltarget"].dataset),
            args.class_num).cuda() / args.class_num

    if args.pl == 'atdoc_nc':
        mem_fea = torch.rand(args.class_num, args.bottleneck_dim).cuda()
        mem_fea = mem_fea / torch.norm(mem_fea, p=2, dim=1, keepdim=True)

    list_acc = []
    best_val_acc = 0

    for iter_num in range(1, args.max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=args.max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=args.max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        try:
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)
        except StopIteration:
            ltarget_loader_iter = iter(dset_loaders["ltarget"])
            inputs_ltarget, labels_ltarget, lidx = next(ltarget_loader_iter)

        inputs_lt = inputs_ltarget[0].cuda()
        inputs_lt2 = inputs_ltarget[1].cuda()
        targets_lt = torch.zeros(args.batch_size // 3,
                                 args.class_num).scatter_(
                                     1, labels_ltarget.view(-1, 1), 1)
        targets_lt = targets_lt.cuda()

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        if args.pl.startswith('atdoc_na'):

            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)

                dis = -torch.mm(features_target.detach(), mem_fea.t())
                for di in range(dis.size(0)):
                    dis[di, target_idx[di]] = torch.max(dis)
                    # dis[di, target_idx[di]+len(dset_loaders["target"].dataset)] = torch.max(dis)

                _, p1 = torch.sort(dis, dim=1)
                w = torch.zeros(features_target.size(0),
                                mem_fea.size(0)).cuda()
                for wi in range(w.size(0)):
                    for wj in range(args.K):
                        w[wi][p1[wi, wj]] = 1 / args.K

                _, pred = torch.max(w.mm(mem_cls), 1)

                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'atdoc_nc':

            targets_u = 0
            mem_fea_norm = mem_fea / torch.norm(
                mem_fea, p=2, dim=1, keepdim=True)
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    features_target, outputs_u = base_network(inp)
                dis = torch.mm(features_target.detach(), mem_fea_norm.t())
                _, pred = torch.max(dis, dim=1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        elif args.pl == 'npl':

            targets_u = 0
            for inp in [inputs_t, inputs_t2]:
                with torch.no_grad():
                    _, outputs_u = base_network(inp)
                _, pred = torch.max(outputs_u.detach(), 1)
                targets_u += 0.5 * torch.eye(outputs_u.size(1))[pred].cuda()

        else:
            with torch.no_grad():
                # compute guessed labels of unlabel samples
                _, outputs_u = base_network(inputs_t)
                _, outputs_u2 = base_network(inputs_t2)
                p = (torch.softmax(outputs_u, dim=1) +
                     torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1 / args.T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat(
            [inputs_s, inputs_lt, inputs_t, inputs_lt2, inputs_t2], dim=0)
        all_targets = torch.cat(
            [targets_s, targets_lt, targets_u, targets_lt, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        # _, logits = base_network(mixed_input[0])
        features, logits = base_network(mixed_input[0])
        logits = [logits]
        for input in mixed_input[1:]:
            _, temp = base_network(input)
            logits.append(temp)

        # put interleaved samples back
        # [i[:,0] for i in aa]
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()

        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, args.max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if args.pl.startswith('atdoc_na'):
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                     nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[target_idx] = (
                1.0 -
                args.momentum) * mem_fea[target_idx] + args.momentum * feat
            mem_cls[target_idx] = (1.0 - args.momentum) * mem_cls[
                target_idx] + args.momentum * softmax_out

            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_lt)
                fea2, outputs2 = base_network(inputs_lt2)
                feat = 0.5 * (fea1 + fea2)
                feat = feat / torch.norm(feat, p=2, dim=1, keepdim=True)
                softmax_out = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                     nn.Softmax(dim=1)(outputs2))
                softmax_out = softmax_out**2 / ((softmax_out**2).sum(dim=0))

            mem_fea[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_fea[lidx + len(dset_loaders["target"].dataset)] + args.momentum*feat
            mem_cls[lidx + len(dset_loaders["target"].dataset)] = (1.0 - args.momentum) * \
                mem_cls[lidx + len(dset_loaders["target"].dataset)] + args.momentum*softmax_out

        if args.pl == 'atdoc_nc':
            base_network.eval()
            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_t)
                fea2, outputs2 = base_network(inputs_t2)
                feat_u = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                   nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tu = torch.eye(args.class_num)[pred_t].cuda()

            with torch.no_grad():
                fea1, outputs1 = base_network(inputs_lt)
                fea2, outputs2 = base_network(inputs_lt2)
                feat_l = 0.5 * (fea1 + fea2)
                softmax_t = 0.5 * (nn.Softmax(dim=1)(outputs1) +
                                   nn.Softmax(dim=1)(outputs2))
                _, pred_t = torch.max(softmax_t, 1)
                onehot_tl = torch.eye(args.class_num)[pred_t].cuda()
                # onehot_tl = torch.eye(args.class_num)[labels_ltarget].cuda()

            center_t = ((torch.mm(feat_u.t(), onehot_tu) + torch.mm(
                feat_l.t(), onehot_tl))) / (onehot_tu.sum(dim=0) +
                                            onehot_tl.sum(dim=0) + 1e-8)
            mem_fea = (1.0 - args.momentum
                       ) * mem_fea + args.momentum * center_t.t().clone()

        if iter_num % int(args.eval_epoch * max_len) == 0:
            base_network.eval()
            if args.dset == 'VISDA-C':
                acc, py, score, y, tacc = utils.cal_acc_visda(
                    dset_loaders["test"], base_network)
                args.out_file.write(tacc + '\n')
                args.out_file.flush()
            else:
                acc, py, score, y = utils.cal_acc(dset_loaders["test"],
                                                  base_network)
                val_acc, _, _, _ = utils.cal_acc(dset_loaders["val"],
                                                 base_network)

            list_acc.append(acc * 100)
            if best_val_acc <= val_acc:
                best_val_acc = val_acc
                best_acc = acc
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Val Acc = {:.2f}%'.format(
                args.name, iter_num, args.max_iter, acc * 100, val_acc * 100)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    val_acc = best_acc * 100
    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})

    return base_network, py
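# A minimal sketch, assuming the MixMatch-style utils.SemiLoss and linear_rampup
# that the domain-adaptation loop above calls (the repo's versions may differ):
# cross-entropy on the mixed labeled targets, an L2 consistency term on the mixed
# unlabeled targets, and a linearly ramped weight for the unlabeled term.
import numpy as np
import torch
import torch.nn.functional as F

def linear_rampup(current, rampup_length):
    return float(np.clip(current / rampup_length, 0.0, 1.0))

class SemiLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u,
                 iter_num, max_iter, lambda_u):
        probs_u = torch.softmax(outputs_u, dim=1)
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u) ** 2)
        return Lx, Lu, lambda_u * linear_rampup(iter_num, max_iter)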
Example #10
def train_mixmatch(label_loader, unlabel_loader, num_classes, model, optimizer,
                   ema_optimizer, epoch, args):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    weights = AverageMeter()
    nu = 2
    end = time.time()

    label_iter = iter(label_loader)
    unlabel_iter = iter(unlabel_loader)

    model.train()
    for i in range(args.val_iteration):

        try:
            (input, _), target = next(label_iter)
        except StopIteration:
            label_iter = iter(label_loader)
            (input, _), target = next(label_iter)
        try:
            (input_ul, input1_ul), _ = next(unlabel_iter)
        except StopIteration:
            unlabel_iter = iter(unlabel_loader)
            (input_ul, input1_ul), _ = next(unlabel_iter)

        bs = input.size(0)
        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda(non_blocking=True)
        input_ul, input1_ul = input_ul.cuda(), input1_ul.cuda()

        with torch.no_grad():
            # compute guess label
            logits = model(torch.cat([input_ul, input1_ul], dim=0))
            p = torch.nn.functional.softmax(logits, dim=-1).view(
                nu, -1, logits.shape[1])
            p_target = p.mean(dim=0).pow(1. / args.T)
            p_target /= p_target.sum(dim=1, keepdim=True)
            guess = p_target.detach_()

            assert input.shape[0] == input_ul.shape[0]

            # mixup
            target_in_onehot = torch.zeros(
                bs,
                num_classes).float().cuda().scatter_(1, target.view(-1, 1), 1)
            mixed_input, mixed_target = mixup(
                torch.cat([input] + [input_ul, input1_ul], dim=0),
                torch.cat([target_in_onehot] + [guess] * nu, dim=0),
                beta=args.beta)
            # reshape to (nu+1, bs, w, h, c)
            mixed_input = mixed_input.reshape([nu + 1] + list(input.shape))
            # reshape to (nu+1, bs)
            mixed_target = mixed_target.reshape([nu + 1] +
                                                list(target_in_onehot.shape))
            input_x, input_u = mixed_input[0], mixed_input[1:]
            target_x, target_u = mixed_target[0], mixed_target[1:]

        model.train()
        batches = interleave([input_x, input_u[0], input_u[1]], bs)
        logits = [model(batches[0])]
        for batchi in batches[1:]:
            logits.append(model(batchi))
        logits = interleave(logits, bs)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], 0)

        # loss
        # cross entropy loss for soft label
        loss_xe = torch.mean(
            torch.sum(-target_x * F.log_softmax(logits_x, dim=-1), dim=1))
        # L2 loss
        loss_l2u = F.mse_loss(F.softmax(logits_u, dim=-1),
                              target_u.reshape(nu * bs, num_classes))
        # weight for unlabeled loss with warmup
        w_match = args.lambda_u * linear_rampup(epoch + i / args.val_iteration,
                                                args.epochs)
        loss = loss_xe + w_match * loss_l2u

        # measure accuracy and record loss
        prec1, prec5 = accuracy(logits_x, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        losses_x.update(loss_xe.item(), input.size(0))
        losses_u.update(loss_l2u.item(), input.size(0))

        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        weights.update(w_match, input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                  'Loss_u {loss_u.val:.4f} ({loss_u.avg:.4f})\t'
                  'Ws {ws.val:.4f}\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      args.val_iteration,
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      loss_x=losses_x,
                      loss_u=losses_u,
                      ws=weights,
                      top1=top1,
                      top5=top5))

    return top1.avg, top5.avg, losses.avg, losses_x.avg, losses_u.avg, weights.avg
Example #11
    def train_iter(
            self,
            model: Classifier,
            labeled_dataset: Dataset,
            unlabeled_dataset: Dataset) -> Generator[Stats, None, Any]:

        labeled_sampler = BatchSampler(RandomSampler(
            labeled_dataset, replacement=True, num_samples=self.num_iters*self.labeled_batch_size),
            batch_size=self.labeled_batch_size, drop_last=True)
        unlabeled_sampler = BatchSampler(RandomSampler(
            unlabeled_dataset, replacement=True, num_samples=self.num_iters*self.unlabeled_batch_size),
            batch_size=self.unlabeled_batch_size, drop_last=True)
        labeled_loader = DataLoader(
            labeled_dataset, batch_sampler=labeled_sampler, num_workers=self.num_workers, pin_memory=True)
        unlabeled_loader = DataLoader(
            unlabeled_dataset, batch_sampler=unlabeled_sampler, num_workers=self.num_workers, pin_memory=True)

        model.to(device=self.devices[0])
        param_avg = self.param_avg_ctor(model)

        # set up optimizer without weight decay on batch norm or bias parameters
        no_wd_filter = lambda m, k: isinstance(m, nn.BatchNorm2d) or k.endswith('bias')
        wd_filter = lambda m, k: not no_wd_filter(m, k)
        optim = self.model_optimizer_ctor([
            {'params': filter_parameters(model, wd_filter)},
            {'params': filter_parameters(model, no_wd_filter), 'weight_decay': 0.}
        ])

        scheduler = self.lr_scheduler_ctor(optim)
        scaler = torch.cuda.amp.GradScaler()

        if self.dist_alignment:
            labeled_dist = get_labeled_dist(labeled_dataset).to(self.devices[0])
            prev_labels = torch.full(
                [self.dist_alignment_batches, model.num_classes], 1 / model.num_classes, device=self.devices[0])
            prev_labels_idx = 0

        # training loop
        for batch_idx, (b_l, b_u) in enumerate(zip(labeled_loader, unlabeled_loader)):
            # labeled examples
            xl, yl = b_l
            yl = yl.cuda(non_blocking=True)

            # augmented pairs of unlabeled examples
            (xw, xs), _ = b_u

            with torch.cuda.amp.autocast(enabled=self.mixed_precision):
                x = torch.cat([xl, xs, xw]).cuda(non_blocking=True)
                num_blocks = x.shape[0] // xl.shape[0]
                x = interleave(x, num_blocks)
                out = torch.nn.parallel.data_parallel(
                    model, x, module_kwargs={'autocast': self.mixed_precision}, device_ids=self.devices)
                out = de_interleave(out, num_blocks)

                # get labels
                with torch.no_grad():
                    probs = torch.softmax(out[-len(xw):], -1)
                    if self.dist_alignment:
                        model_dist = prev_labels.mean(0)
                        prev_labels[prev_labels_idx] = probs.mean(0)
                        prev_labels_idx = (prev_labels_idx + 1) % self.dist_alignment_batches
                        probs *= (labeled_dist + self.dist_alignment_eps) / (model_dist + self.dist_alignment_eps)
                        probs /= probs.sum(-1, keepdim=True)
                    yu = torch.argmax(probs, -1)
                    mask = (torch.max(probs, -1)[0] >= self.threshold).to(dtype=torch.float32)

                loss_l = F.cross_entropy(out[:len(xl)], yl, reduction='mean')
                loss_u = (mask * F.cross_entropy(out[len(xl):-len(xw)], yu, reduction='none')).mean()
                loss = loss_l + self.unlabeled_weight * loss_u

            model.zero_grad()
            if self.mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                optim.step()
            param_avg.step()
            scheduler.step()

            yield self.Stats(
                iter=batch_idx+1,
                loss=loss.cpu().item(),
                loss_labeled=loss_l.cpu().item(),
                loss_unlabeled=loss_u.cpu().item(),
                model=model,
                avg_model=param_avg.avg_model,
                optimizer=optim,
                scheduler=scheduler,
                threshold_frac=mask.mean().cpu().item())
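# A standalone sketch of the distribution-alignment step embedded in the loop
# above, written out as an assumed helper for clarity (names are illustrative):
# unlabeled predictions are rescaled by the ratio of the labeled class
# distribution to a running estimate of the model's predicted distribution.
import torch

def align_distribution(probs, labeled_dist, model_dist, eps=1e-6):
    probs = probs * (labeled_dist + eps) / (model_dist + eps)
    return probs / probs.sum(-1, keepdim=True)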
Example #12
    def train_iter(
            self,
            model: Classifier,
            num_classes: int,
            labeled_dataset: Dataset,
            unlabeled_dataset: Dataset) -> Generator[Stats, None, Any]:

        labeled_sampler = BatchSampler(RandomSampler(
            labeled_dataset, replacement=True, num_samples=self.batches_per_epoch * self.labeled_batch_size),
            batch_size=self.labeled_batch_size, drop_last=True)
        labeled_loader = DataLoader(
            labeled_dataset, batch_sampler=labeled_sampler, num_workers=self.num_workers, pin_memory=True)
        unlabeled_sampler = BatchSampler(RandomSampler(
            unlabeled_dataset, replacement=True,
            num_samples=self.batches_per_epoch * self.unlabeled_batch_size),
            batch_size=self.unlabeled_batch_size, drop_last=True)
        unlabeled_loader = DataLoader(
            unlabeled_dataset, batch_sampler=unlabeled_sampler, num_workers=self.num_workers, pin_memory=True)

        # initialize model and optimizer
        model.to(device=self.devices[0])
        param_avg = self.param_avg_ctor(model)

        # set up optimizer without weight decay on batch norm or bias parameters
        no_wd_filter = lambda m, k: isinstance(m, nn.BatchNorm2d) or k.endswith('bias')
        wd_filter = lambda m, k: not no_wd_filter(m, k)
        optim = self.model_optimizer_ctor([
            {'params': filter_parameters(model, wd_filter)},
            {'params': filter_parameters(model, no_wd_filter), 'weight_decay': 0.}
        ])

        scheduler = self.lr_scheduler_ctor(optim)
        scaler = torch.cuda.amp.GradScaler()

        # initialize label assignment
        log_upper_bounds = get_log_upper_bounds(
            labeled_dataset, method=self.upper_bound_method, **self.upper_bound_kwargs)
        logger.info('upper bounds = {}'.format(torch.exp(log_upper_bounds)))
        label_assgn = SinkhornLabelAllocation(
            num_examples=len(unlabeled_dataset),
            log_upper_bounds=log_upper_bounds,
            allocation_param=0.,
            entropy_reg=self.entropy_reg,
            update_tol=self.update_tol,
            device=self.devices[0])

        # training loop
        for epoch in range(self.num_epochs):
            # (1) update model
            for batch_idx, (b_l, b_u) in enumerate(zip(labeled_loader, unlabeled_loader)):
                # labeled examples
                xl, yl = b_l
                yl = yl.cuda(non_blocking=True)

                # augmented pairs of unlabeled examples
                (xu1, xu2), idxs = b_u

                with torch.cuda.amp.autocast(enabled=self.mixed_precision):
                    x = torch.cat([xl, xu1, xu2]).cuda(non_blocking=True)
                    if len(self.devices) > 1:
                        num_blocks = x.shape[0] // xl.shape[0]
                        x = interleave(x, num_blocks)
                        out = torch.nn.parallel.data_parallel(
                            model, x, module_kwargs={'autocast': self.mixed_precision}, device_ids=self.devices)
                        out = de_interleave(out, num_blocks)
                    else:
                        out = model(x, autocast=self.mixed_precision)

                    # compute labels
                    logp_u = F.log_softmax(out[len(xl):], -1)
                    nu = logp_u.shape[0] // 2
                    qu = label_assgn.get_plan(log_p=logp_u[:nu].detach()).to(dtype=torch.float32, device=out.device)
                    qu = qu[:, :-1]

                    # compute loss
                    loss_l = F.cross_entropy(out[:len(xl)], yl, reduction='mean')
                    loss_u = -(qu * logp_u[nu:]).sum(-1).mean()
                    loss = loss_l + self.unlabeled_weight * loss_u

                    # update plan
                    rho = self.allocation_schedule(
                        (epoch * self.batches_per_epoch + batch_idx + 1) /
                        (self.num_epochs * self.batches_per_epoch))
                    label_assgn.set_allocation_param(rho)
                    label_assgn.update_loss_matrix(logp_u[:nu], idxs)
                    assgn_err, assgn_iters = label_assgn.update()

                optim.zero_grad()
                if self.mixed_precision:
                    scaler.scale(loss).backward()
                    scaler.step(optim)
                    scaler.update()
                else:
                    loss.backward()
                    optim.step()
                param_avg.step()
                scheduler.step()

                yield self.Stats(
                    iter=epoch * self.batches_per_epoch + batch_idx + 1,
                    loss=loss.cpu().item(),
                    loss_labeled=loss_l.cpu().item(),
                    loss_unlabeled=loss_u.cpu().item(),
                    model=model,
                    avg_model=param_avg.avg_model,
                    allocation_param=rho,
                    optimizer=optim,
                    scheduler=scheduler,
                    label_vars=qu,
                    scaling_vars=label_assgn.v.data,
                    assgn_err=assgn_err,
                    assgn_iters=assgn_iters)
Example #13
def find_positive_set(iterations, tolerance_a, tolerance_b, param_boundaries):
    """
    Find the set of boundaries that should be classified positively.

    First, find the upper and lower edge bounds. Then, move the midpoints of 
    positively-classified boundaries out until they are negatively classified.
    Recurse for the midpoints of the line segments created in this process for
    some number of iterations.

    Argument types:
    iterations -- integer
    tolerance_a -- float. Should be positive.
    tolerance_b -- float. Should be positive.
    param_boundaries -- list of form [[lower_bound_x, upper_bound_x], 
                                      [lower_bound_y, upper_bound_y]],
                        describing the dimensions of parameter space

    note: should really separate out tolerance for how finely we're generating
    boundaries vs tolerance for how careful we are about where the boundary of 
    the convex set is. at the moment we'll just call them tolerance_a and 
    tolerance_b, need better names
    """
    assert isinstance(
        iterations, int), "iterations isn't an integer in find_positive_set"
    assert iterations > 0, "you can't have a non-positive number of iterations in find_positive_set"
    assert isinstance(
        tolerance_a,
        float), "tolerance_a should be a float in find_positive_set"
    assert tolerance_a > 0, "tolerance_a should be positive in find_positive_set"
    assert isinstance(
        tolerance_b,
        float), "tolerance_b should be a float in find_positive_set"
    assert tolerance_b > 0, "tolerance_b should be positive in find_positive_set"

    # get a positive example
    positive_example = get_positive_example(param_boundaries)

    print("Positive example gained.")

    # find where the endpoints can be in the convex set
    endpoint_bounds = find_endpoint_bounds(positive_example, tolerance_a,
                                           tolerance_b, param_boundaries)
    upper_bound = [endpoint_bounds[0], endpoint_bounds[1]]
    lower_bound = [endpoint_bounds[2], endpoint_bounds[3]]

    print("after finding endpoint bounds, upper bound is", upper_bound)
    print("after finding endpoint bounds, lower bound is", lower_bound)

    # repeatedly find the midpoints of the line segments in the lower and upper
    # bounds, and move them out as far as possible
    for i in range(iterations):
        print("in loop number " + str(i) + " when moving out midpoints")
        assert len(upper_bound) == len(
            lower_bound
        ), "somehow upper and lower bounds became a different length in the loop of find_positive_set"
        upper_extensions = []
        lower_extensions = []
        for j in range(len(upper_bound) - 1):
            new_upper = maximally_extend_segment(upper_bound, j, tolerance_a,
                                                 tolerance_b, True)
            new_lower = maximally_extend_segment(lower_bound, j, tolerance_a,
                                                 tolerance_b, False)
            upper_extensions.append(new_upper)
            lower_extensions.append(new_lower)
        assert len(upper_extensions) == len(
            lower_extensions
        ), "upper_extensions and lower_extensions should end up having the same length after the end of the inner loop of find_positive_set"
        assert len(upper_extensions) == len(
            upper_bound
        ) - 1, "upper_extensions should have length 1 less than upper_bound after the inner loop of find_positive_set"
        upper_bound = utils.interleave(upper_extensions, upper_bound)
        lower_bound = utils.interleave(lower_extensions, lower_bound)
    utils.plot(upper_bound, 'tab:orange', param_boundaries)
    utils.plot(lower_bound, 'tab:orange', param_boundaries)
    return (upper_bound, lower_bound)
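# A hypothetical sketch of the list interleave utils.interleave is assumed to
# perform above: the n midpoint extensions are woven between the n + 1 existing
# boundary points, preserving their order.
def interleave(extensions, bound):
    out = [bound[0]]
    for ext, pt in zip(extensions, bound[1:]):
        out.extend([ext, pt])
    return out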
Example #14
def test_interleave_simple():
    x = torch.arange(960)[:, None]
    y = interleave(x, 15)
    true_y = np_interleave(x.numpy(), 15)
    assert (y.numpy() == true_y).all()
Example #15
def train(args, txt_src, txt_tgt):
    ## set pre-process
    dset_loaders = data_load(args, txt_src, txt_tgt)
    # pdb.set_trace()
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch * max_len
    interval_iter = max_iter // 10

    if args.dset == 'u2m':
        netG = network.LeNetBase().cuda()
    elif args.dset == 'm2u':
        netG = network.LeNetBase().cuda()
    elif args.dset == 's2m':
        netG = network.DTNBase().cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netG.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt"
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"
        netB.load_state_dict(torch.load(modelpath))

    netF = nn.Sequential(netB, netC)
    optimizer_g = optim.SGD(netG.parameters(), lr=args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr=args.lr)

    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g,
                     init_lr=args.lr * 0.1,
                     iter_num=iter_num,
                     max_iter=max_iter)
        lr_scheduler(optimizer_f,
                     init_lr=args.lr,
                     iter_num=iter_num,
                     max_iter=max_iter)

        try:
            inputs_source, labels_source = next(source_loader_iter)
        except StopIteration:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        try:
            inputs_target, _, target_idx = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)

        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(
            1, labels_source.view(-1, 1), 1)
        inputs_s = inputs_source.cuda()
        targets_s = targets_s.cuda()
        inputs_t = inputs_target[0].cuda()
        inputs_t2 = inputs_target[1].cuda()

        with torch.no_grad():
            # compute guessed labels of unlabel samples
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) +
                 torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        ####################################################################
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)
        else:
            l = 1
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = utils.interleave(mixed_input, args.batch_size)
        # s = [sa, sb, sc]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]

        logits = base_network(mixed_input[0])
        logits = [logits]
        for input in mixed_input[1:]:
            temp = base_network(input)
            logits.append(temp)

        # put interleaved samples back
        # [i[:,0] for i in aa]
        logits = utils.interleave(logits, args.batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        train_criterion = utils.SemiLoss()

        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size],
                                    logits_u, mixed_target[args.batch_size:],
                                    iter_num, max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()

            acc, py, score, y = cal_acc(dset_loaders["train"],
                                        base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_train', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            acc, py, score, y = cal_acc(dset_loaders["test"],
                                        base_network,
                                        flag=False)
            mean_ent = torch.mean(Entropy(score))
            list_acc.append(acc)

            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(
                args.dset + '_test', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(
        val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(),
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})

    return base_network, py
Example #16
            batchX = next(iteratorX)
        try:
            batchU = next(iteratorU)
        except StopIteration:
            iteratorU = iter(shuffle_and_batch(trainU))
            batchU = next(iteratorU)

        #args['beta'].assign(np.random.beta(args['alpha'], args['alpha']))
        beta = np.random.beta(0.75,0.75)
        with tf.GradientTape() as tape:
            # run mixmatch
            XU, XUy = mixmatch(model, batchX['image'], batchX['label'], batchU['image'], 0.5, 2, beta)
            logits = [model(XU[0])]
            for batch in XU[1:]:
                logits.append(model(batch))
            logits = interleave(logits, 64)
            logits_x = logits[0]
            logits_u = tf.concat(logits[1:], axis=0)

            # compute loss
            xe_loss, l2u_loss = semi_loss(XUy[:64], logits_x, XUy[64:], logits_u)
            total_loss = xe_loss + lambda_u * l2u_loss

        # compute gradients and run optimizer step
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        ema(model, ema_model, 0.999)
        weight_decay(model=model, decay_rate=0.02 * 0.002)

        xe_loss_avg(xe_loss)
        l2u_loss_avg(l2u_loss)
Example #17
 def set(self, observations, actions):
     return Memoizer(self.agent, self.cache,
                     interleave(observations, actions))
Example #18
 def format_with(self, field_strings):
     return "".join(utils.interleave(self.text, field_strings))
Example #19
def train_one_epoch(epoch, model, criteria_x, criteria_u, criteria_z, optim,
                    lr_schdlr, ema, dltrain_x, dltrain_u, dltrain_f,
                    lb_guessor, lambda_u, lambda_s, n_iters, logger, bt, mu):
    """
    TRAINING FUNCTION FOR FIXMATCH AND SIMCLR IN THE SAME EPOCH
    """
    model.train()
    # loss_meter, loss_x_meter, loss_u_meter, loss_u_real_meter = [], [], [], []
    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()
    loss_u_real_meter = AverageMeter()
    loss_simclr_meter = AverageMeter()
    # number of unlabeled samples that pass the confidence threshold and are correctly pseudo-labeled
    n_correct_u_lbs_meter = AverageMeter()
    # number of strongly-augmented unlabeled samples whose confidence exceeds the threshold (i.e. contribute gradient)
    n_strong_aug_meter = AverageMeter()
    mask_meter = AverageMeter()

    epoch_start = time.time()  # start time
    dl_x, dl_u, dl_f = iter(dltrain_x), iter(dltrain_u), iter(dltrain_f)
    # dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
    for it in range(n_iters):
        ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
        ims_u_weak, ims_u_strong, lbs_u_real = next(
            dl_u)  # FixMatch transformations
        ims_s_weak, ims_s_strong, lbs_s_real = next(
            dl_f)  # with SimCLR transformations

        lbs_x = lbs_x.cuda()
        lbs_u_real = lbs_u_real.cuda()

        # --------------------------------------
        imgs = torch.cat(
            [ims_x_weak, ims_u_weak, ims_u_strong, ims_s_weak, ims_s_strong],
            dim=0).cuda()
        # imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
        imgs = interleave(imgs, 4 * mu + 1)
        # imgs = interleave(imgs, 2 * mu + 1)
        logits, logit_z, _ = model(imgs)
        logits = de_interleave(logits, 4 * mu + 1)
        logit_z = de_interleave(logit_z, 4 * mu + 1)
        # logits = de_interleave(logits, 2 * mu + 1)

        # SPLIT LOGITS FOR THE SUPERVISED FIXMATCH STAGE
        logits_x = logits[:bt]
        # SPLIT LOGITS FOR THE UNSUPERVISED FIXMATCH STAGE
        logits_u_w, logits_u_s, _, _ = torch.split(logits[bt:], bt * mu)
        # SPLIT LOGITS FOR THE UNSUPERVISED SIMCLR STAGE
        _, _, logits_s_w, logits_s_s = torch.split(logit_z[bt:], bt * mu)

        # compute the mask from the weak FixMatch augmentation
        with torch.no_grad():
            probs = torch.softmax(logits_u_w, dim=1)
            scores, lbs_u_guess = torch.max(probs, dim=1)
            mask = scores.ge(0.95).float()

        # compute the losses
        loss_s = criteria_z(logits_s_w, logits_s_s)
        loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
        loss_x = criteria_x(logits_x, lbs_x)

        loss = loss_x + loss_u * lambda_u + loss_s * lambda_s

        loss_u_real = (F.cross_entropy(logits_u_s, lbs_u_real) * mask).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())
        loss_u_real_meter.update(loss_u_real.item())
        loss_simclr_meter.update(loss_s.item())
        mask_meter.update(mask.mean().item())

        corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
        n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
        n_strong_aug_meter.update(mask.sum().item())

        if (it + 1) % 512 == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info(
                "epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
                "n_correct_u: {:.2f}/{:.2f}. loss_s: {:.4f}. "
                "Mask:{:.4f}. LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg,
                    loss_x_meter.avg, loss_u_real_meter.avg,
                    n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg,
                    loss_simclr_meter.avg, mask_meter.avg, lr_log, t))

            # logger.info("epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
            #             "n_correct_u: {:.2f}/{:.2f}."
            #             "Mask:{:.4f} . LR: {:.4f}. Time: {:.2f}".format(
            #     epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, loss_u_real_meter.avg,
            #     n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()
    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg,\
           loss_u_real_meter.avg, mask_meter.avg, loss_simclr_meter.avg
Example #20
def test_interleave_deinterleave():
    x = torch.rand(960, 3, 32, 32)
    y = interleave(x, 15)
    z = deinterleave(y, 15)
    assert (z == x).all()