Example 1
def train_epoch(train_loader, model_list, optimizer_list, epoch, log):
    global global_step

    meters = AverageMeterSet()

    # define criterions
    class_criterion = nn.CrossEntropyLoss(size_average=False,
                                          ignore_index=NO_LABEL).cuda()
    residual_logit_criterion = losses.symmetric_mse_loss
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
        stabilization_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
        stabilization_criterion = losses.softmax_kl_loss

    for model in model_list:
        model.train()

    end = time.time()
    for i, (input_list, target) in enumerate(train_loader):
        meters.update('data_time', time.time() - end)

        for odx, optimizer in enumerate(optimizer_list):
            adjust_learning_rate(optimizer, epoch, i, len(train_loader))
            meters.update('lr_{0}'.format(odx),
                          optimizer.param_groups[0]['lr'])

        input_var_list, nograd_input_var_list = [], []
        for idx, inp in enumerate(input_list):
            input_var_list.append(Variable(inp))
            nograd_input_var_list.append(
                Variable(inp, requires_grad=False, volatile=True))

        target_var = Variable(target.cuda(async=True))

        minibatch_size = len(target_var)
        labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum()
        unlabeled_minibatch_size = minibatch_size - labeled_minibatch_size
        assert labeled_minibatch_size >= 0 and unlabeled_minibatch_size >= 0
        meters.update('labeled_minibatch_size', labeled_minibatch_size)
        meters.update('unlabeled_minibatch_size', unlabeled_minibatch_size)
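        # NO_LABEL marks unlabeled samples: the classification loss ignores them
        # (via ignore_index), but they still drive the consistency and
        # stabilization terms below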

        loss_list = []
        cls_v_list, nograd_cls_v_list = [], []
        cls_i_list, nograd_cls_i_list = [], []
        mask_list, nograd_mask_list = [], []
        class_logit_list, nograd_class_logit_list = [], []
        cons_logit_list = []
        in_cons_logit_list, tar_class_logit_list = [], []

        # for each student model
        for mdx, model in enumerate(model_list):
            # forward
            class_logit, cons_logit = model(input_var_list[mdx])
            nograd_class_logit, nograd_cons_logit = model(
                nograd_input_var_list[mdx])

            # calculate - res_loss, class_loss, consistency_loss - inside each student model
            res_loss = args.logit_distance_cost * residual_logit_criterion(
                class_logit, cons_logit) / minibatch_size
            meters.update('{0}_res_loss'.format(mdx), res_loss.data[0])

            class_loss = class_criterion(class_logit,
                                         target_var) / minibatch_size
            meters.update('{0}_class_loss'.format(mdx), class_loss.data[0])

            consistency_weight = args.consistency_scale * ramps.sigmoid_rampup(
                epoch, args.consistency_rampup)
            nograd_class_logit = Variable(nograd_class_logit.detach().data,
                                          requires_grad=False)
            consistency_loss = consistency_weight * consistency_criterion(
                cons_logit, nograd_class_logit) / minibatch_size
            meters.update('{0}_cons_loss'.format(mdx),
                          consistency_loss.data[0])

            loss = class_loss + res_loss + consistency_loss
            loss_list.append(loss)

            # store variables for calculating the stabilization loss
            cls_v, cls_i = torch.max(F.softmax(class_logit, dim=1), dim=1)
            nograd_cls_v, nograd_cls_i = torch.max(
                F.softmax(nograd_class_logit, dim=1), dim=1)
            cls_v_list.append(cls_v)
            cls_i_list.append(cls_i.data.cpu().numpy())
            nograd_cls_v_list.append(nograd_cls_v)
            nograd_cls_i_list.append(nograd_cls_i.data.cpu().numpy())

            # stable prediction mask: max softmax probability above the threshold
            mask = (cls_v > args.stable_threshold)
            nograd_mask = (nograd_cls_v > args.stable_threshold)
            mask_list.append(mask.data.cpu().numpy())
            nograd_mask_list.append(nograd_mask.data.cpu().numpy())

            class_logit_list.append(class_logit)
            cons_logit_list.append(cons_logit)
            nograd_class_logit_list.append(nograd_class_logit)

            in_cons_logit = Variable(cons_logit.detach().data,
                                     requires_grad=False)
            in_cons_logit_list.append(in_cons_logit)

            tar_class_logit = Variable(class_logit.clone().detach().data,
                                       requires_grad=False)
            tar_class_logit_list.append(tar_class_logit)

        # calculate stabilization weight
        stabilization_weight = args.stabilization_scale * ramps.sigmoid_rampup(
            epoch, args.stabilization_rampup)
        if not args.exclude_unlabeled:
            stabilization_weight = (unlabeled_minibatch_size /
                                    minibatch_size) * stabilization_weight

        model_idx = np.arange(0, len(model_list))
        np.random.shuffle(model_idx)

        for idx in range(0, len(model_idx)):
            if idx % 2 != 0:
                continue

            # l and r form a Dual Student pair (model_idx is assumed to have even length)
            l_mdx, r_mdx = model_idx[idx], model_idx[idx + 1]

            for sdx in range(0, minibatch_size):
                l_stable = False
                # unstable: do not satisfy the 2nd condition
                if mask_list[l_mdx][sdx] == 0 and nograd_mask_list[l_mdx][
                        sdx] == 0:
                    tar_class_logit_list[l_mdx][
                        sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]
                # unstable: do not satisfy the 1st condition
                elif cls_i_list[l_mdx][sdx] != nograd_cls_i_list[l_mdx][sdx]:
                    tar_class_logit_list[l_mdx][
                        sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]
                else:
                    l_stable = True

                r_stable = False
                # unstable: do not satisfy the 2nd condition
                if mask_list[r_mdx][sdx] == 0 and nograd_mask_list[r_mdx][
                        sdx] == 0:
                    tar_class_logit_list[r_mdx][
                        sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                # unstable: do not satisfy the 1st condition
                elif cls_i_list[r_mdx][sdx] != nograd_cls_i_list[r_mdx][sdx]:
                    tar_class_logit_list[r_mdx][
                        sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                else:
                    r_stable = True

                # calculate stability if both l and r models are stable for a sample
                if l_stable and r_stable:
                    l_sample_cons = consistency_criterion(
                        cons_logit_list[l_mdx][sdx:sdx + 1, ...],
                        nograd_class_logit_list[r_mdx][sdx:sdx + 1, ...])
                    r_sample_cons = consistency_criterion(
                        cons_logit_list[r_mdx][sdx:sdx + 1, ...],
                        nograd_class_logit_list[l_mdx][sdx:sdx + 1, ...])
                    l_cons_val = l_sample_cons.data.cpu().numpy()[0]
                    r_cons_val = r_sample_cons.data.cpu().numpy()[0]
                    # loss: l -> r
                    if l_cons_val < r_cons_val:
                        tar_class_logit_list[r_mdx][
                            sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]
                    # loss: r -> l
                    elif l_cons_val > r_cons_val:
                        tar_class_logit_list[l_mdx][
                            sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]

            if args.exclude_unlabeled:
                l_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[l_mdx],
                    tar_class_logit_list[r_mdx]) / minibatch_size
                r_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[r_mdx],
                    tar_class_logit_list[l_mdx]) / minibatch_size
            else:
                for sdx in range(unlabeled_minibatch_size, minibatch_size):
                    tar_class_logit_list[l_mdx][
                        sdx, ...] = in_cons_logit_list[r_mdx][sdx, ...]
                    tar_class_logit_list[r_mdx][
                        sdx, ...] = in_cons_logit_list[l_mdx][sdx, ...]

                l_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[l_mdx],
                    tar_class_logit_list[r_mdx]) / unlabeled_minibatch_size
                r_stabilization_loss = stabilization_weight * stabilization_criterion(
                    cons_logit_list[r_mdx],
                    tar_class_logit_list[l_mdx]) / unlabeled_minibatch_size

            meters.update('{0}_stable_loss'.format(l_mdx),
                          l_stabilization_loss.data[0])
            meters.update('{0}_stable_loss'.format(r_mdx),
                          r_stabilization_loss.data[0])

            loss_list[l_mdx] = loss_list[l_mdx] + l_stabilization_loss
            loss_list[r_mdx] = loss_list[r_mdx] + r_stabilization_loss

            meters.update('{0}_loss'.format(l_mdx), loss_list[l_mdx].data[0])
            meters.update('{0}_loss'.format(r_mdx), loss_list[r_mdx].data[0])

        for mdx, model in enumerate(model_list):
            # calculate prec
            prec = mt_func.accuracy(class_logit_list[mdx].data,
                                    target_var.data,
                                    topk=(1, ))[0]
            meters.update('{0}_top1'.format(mdx), prec[0],
                          labeled_minibatch_size)

            # backward and update
            optimizer_list[mdx].zero_grad()
            loss_list[mdx].backward()
            optimizer_list[mdx].step()

        # record
        global_step += 1
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            LOG.info('Epoch: [{0}][{1}/{2}]\t'
                     'Batch-T {meters[batch_time]:.3f}\t'.format(
                         epoch, i, len(train_loader), meters=meters))

            for mdx, model in enumerate(model_list):
                cur_class_loss = meters['{0}_class_loss'.format(mdx)].val
                avg_class_loss = meters['{0}_class_loss'.format(mdx)].avg
                cur_res_loss = meters['{0}_res_loss'.format(mdx)].val
                avg_res_loss = meters['{0}_res_loss'.format(mdx)].avg
                cur_cons_loss = meters['{0}_cons_loss'.format(mdx)].val
                avg_cons_loss = meters['{0}_cons_loss'.format(mdx)].avg
                cur_stable_loss = meters['{0}_stable_loss'.format(mdx)].val
                avg_stable_loss = meters['{0}_stable_loss'.format(mdx)].avg
                cur_top1_acc = meters['{0}_top1'.format(mdx)].val
                avg_top1_acc = meters['{0}_top1'.format(mdx)].avg
                LOG.info(
                    'model-{0}: Class {1:.4f}({2:.4f})\tRes {3:.4f}({4:.4f})\tCons {5:.4f}({6:.4f})\t'
                    'Stable {7:.4f}({8:.4f})\tPrec@1 {9:.3f}({10:.3f})\t'.
                    format(mdx, cur_class_loss, avg_class_loss, cur_res_loss,
                           avg_res_loss, cur_cons_loss, avg_cons_loss,
                           cur_stable_loss, avg_stable_loss, cur_top1_acc,
                           avg_top1_acc))

            LOG.info('\n')
            log.record(
                epoch + i / len(train_loader), {
                    'step': global_step,
                    **meters.values(),
                    **meters.averages(),
                    **meters.sums()
                })
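
Note: both examples assume a project-level `losses` module providing the criteria referenced above. The snippet below is a minimal sketch of what those helpers usually look like in Mean Teacher style implementations (softmax-based MSE/KL for the consistency terms and a symmetric MSE for the residual logits); the names match the call sites above, but the actual module in the repository may differ.

# Hedged sketch of the assumed `losses` helpers; not necessarily the
# repository's exact implementation.
import torch
import torch.nn.functional as F


def softmax_mse_loss(input_logits, target_logits):
    # MSE between the two softmax distributions, summed over the batch
    # (the callers divide by the minibatch size afterwards).
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    num_classes = input_logits.size()[1]
    return F.mse_loss(input_softmax, target_softmax,
                      size_average=False) / num_classes


def softmax_kl_loss(input_logits, target_logits):
    # KL divergence between the target and input softmax distributions,
    # summed over the batch.
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    return F.kl_div(input_log_softmax, target_softmax, size_average=False)


def symmetric_mse_loss(input1, input2):
    # MSE that propagates gradients to both inputs, used as the
    # residual (logit distance) criterion.
    assert input1.size() == input2.size()
    num_classes = input1.size()[1]
    return torch.sum((input1 - input2) ** 2) / num_classes
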
Example 2
def train_epoch(train_loader, l_model, r_model, l_optimizer, r_optimizer,
                epoch, log):
    global global_step

    meters = AverageMeterSet()

    # define criterions
    class_criterion = nn.CrossEntropyLoss(size_average=False,
                                          ignore_index=NO_LABEL).cuda()
    residual_logit_criterion = losses.symmetric_mse_loss
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
        stabilization_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
        stabilization_criterion = losses.softmax_kl_loss

    l_model.train()
    r_model.train()

    end = time.time()
    for i, ((l_input, r_input), target) in enumerate(train_loader):
        meters.update('data_time', time.time() - end)

        # adjust learning rate
        adjust_learning_rate(l_optimizer, epoch, i, len(train_loader))
        adjust_learning_rate(r_optimizer, epoch, i, len(train_loader))
        meters.update('l_lr', l_optimizer.param_groups[0]['lr'])
        meters.update('r_lr', r_optimizer.param_groups[0]['lr'])

        # prepare data
        l_input_var = Variable(l_input)
        r_input_var = Variable(r_input)
        le_input_var = Variable(r_input, requires_grad=False, volatile=True)
        re_input_var = Variable(l_input, requires_grad=False, volatile=True)
        target_var = Variable(target.cuda(async=True))

        minibatch_size = len(target_var)
        labeled_minibatch_size = target_var.data.ne(NO_LABEL).sum()
        unlabeled_minibatch_size = minibatch_size - labeled_minibatch_size
        assert labeled_minibatch_size >= 0 and unlabeled_minibatch_size >= 0
        meters.update('labeled_minibatch_size', labeled_minibatch_size)
        meters.update('unlabeled_minibatch_size', unlabeled_minibatch_size)

        # forward
        l_model_out = l_model(l_input_var)
        r_model_out = r_model(r_input_var)
        le_model_out = l_model(le_input_var)
        re_model_out = r_model(re_input_var)

        if isinstance(l_model_out, Variable):
            assert args.logit_distance_cost < 0
            l_logit1 = l_model_out
            r_logit1 = r_model_out
            le_logit1 = le_model_out
            re_logit1 = re_model_out
        elif len(l_model_out) == 2:
            assert len(r_model_out) == 2
            l_logit1, l_logit2 = l_model_out
            r_logit1, r_logit2 = r_model_out
            le_logit1, le_logit2 = le_model_out
            re_logit1, re_logit2 = re_model_out

        # logit distance loss from mean teacher
        if args.logit_distance_cost >= 0:
            l_class_logit, l_cons_logit = l_logit1, l_logit2
            r_class_logit, r_cons_logit = r_logit1, r_logit2
            le_class_logit, le_cons_logit = le_logit1, le_logit2
            re_class_logit, re_cons_logit = re_logit1, re_logit2

            l_res_loss = args.logit_distance_cost * residual_logit_criterion(
                l_class_logit, l_cons_logit) / minibatch_size
            r_res_loss = args.logit_distance_cost * residual_logit_criterion(
                r_class_logit, r_cons_logit) / minibatch_size
            meters.update('l_res_loss', l_res_loss.data[0])
            meters.update('r_res_loss', r_res_loss.data[0])
        else:
            l_class_logit, l_cons_logit = l_logit1, l_logit1
            r_class_logit, r_cons_logit = r_logit1, r_logit1
            le_class_logit, le_cons_logit = le_logit1, le_logit1
            re_class_logit, re_cons_logit = re_logit1, re_logit1

            l_res_loss = 0.0
            r_res_loss = 0.0
            meters.update('l_res_loss', 0.0)
            meters.update('r_res_loss', 0.0)

        # classification loss
        l_class_loss = class_criterion(l_class_logit,
                                       target_var) / minibatch_size
        r_class_loss = class_criterion(r_class_logit,
                                       target_var) / minibatch_size
        meters.update('l_class_loss', l_class_loss.data[0])
        meters.update('r_class_loss', r_class_loss.data[0])

        l_loss, r_loss = l_class_loss, r_class_loss
        l_loss += l_res_loss
        r_loss += r_res_loss

        # consistency loss
        consistency_weight = args.consistency_scale * ramps.sigmoid_rampup(
            epoch, args.consistency_rampup)

        le_class_logit = Variable(le_class_logit.detach().data,
                                  requires_grad=False)
        l_consistency_loss = consistency_weight * consistency_criterion(
            l_cons_logit, le_class_logit) / minibatch_size
        meters.update('l_cons_loss', l_consistency_loss.data[0])
        l_loss += l_consistency_loss

        re_class_logit = Variable(re_class_logit.detach().data,
                                  requires_grad=False)
        r_consistency_loss = consistency_weight * consistency_criterion(
            r_cons_logit, re_class_logit) / minibatch_size
        meters.update('r_cons_loss', r_consistency_loss.data[0])
        r_loss += r_consistency_loss

        # stabilization loss
        # value (cls_v) and index (cls_i) of the max probability in the prediction
        l_cls_v, l_cls_i = torch.max(F.softmax(l_class_logit, dim=1), dim=1)
        r_cls_v, r_cls_i = torch.max(F.softmax(r_class_logit, dim=1), dim=1)
        le_cls_v, le_cls_i = torch.max(F.softmax(le_class_logit, dim=1), dim=1)
        re_cls_v, re_cls_i = torch.max(F.softmax(re_class_logit, dim=1), dim=1)

        l_cls_i = l_cls_i.data.cpu().numpy()
        r_cls_i = r_cls_i.data.cpu().numpy()
        le_cls_i = le_cls_i.data.cpu().numpy()
        re_cls_i = re_cls_i.data.cpu().numpy()

        # stable prediction mask
        l_mask = (l_cls_v > args.stable_threshold).data.cpu().numpy()
        r_mask = (r_cls_v > args.stable_threshold).data.cpu().numpy()
        le_mask = (le_cls_v > args.stable_threshold).data.cpu().numpy()
        re_mask = (re_cls_v > args.stable_threshold).data.cpu().numpy()

        # detach logits for generating the stabilization targets
        in_r_cons_logit = Variable(r_cons_logit.detach().data,
                                   requires_grad=False)
        tar_l_class_logit = Variable(l_class_logit.clone().detach().data,
                                     requires_grad=False)

        in_l_cons_logit = Variable(l_cons_logit.detach().data,
                                   requires_grad=False)
        tar_r_class_logit = Variable(r_class_logit.clone().detach().data,
                                     requires_grad=False)

        # generate target for each sample
        for sdx in range(0, minibatch_size):
            l_stable = False
            if l_mask[sdx] == 0 and le_mask[sdx] == 0:
                # unstable: do not satisfy 2nd condition
                tar_l_class_logit[sdx, ...] = in_r_cons_logit[sdx, ...]
            elif l_cls_i[sdx] != le_cls_i[sdx]:
                # unstable: do not satisfy 1st condition
                tar_l_class_logit[sdx, ...] = in_r_cons_logit[sdx, ...]
            else:
                l_stable = True

            r_stable = False
            if r_mask[sdx] == 0 and re_mask[sdx] == 0:
                # unstable: do not satisfy 2nd condition
                tar_r_class_logit[sdx, ...] = in_l_cons_logit[sdx, ...]
            elif r_cls_i[sdx] != re_cls_i[sdx]:
                # unstable: do not satisfy 1st condition
                tar_r_class_logit[sdx, ...] = in_l_cons_logit[sdx, ...]
            else:
                r_stable = True

            # calculate stability if both models are stable for a sample
            if l_stable and r_stable:
                # compare by consistency
                l_sample_cons = consistency_criterion(
                    l_cons_logit[sdx:sdx + 1, ...],
                    le_class_logit[sdx:sdx + 1, ...])
                r_sample_cons = consistency_criterion(
                    r_cons_logit[sdx:sdx + 1, ...],
                    re_class_logit[sdx:sdx + 1, ...])
                l_cons_val = l_sample_cons.data.cpu().numpy()[0]
                r_cons_val = r_sample_cons.data.cpu().numpy()[0]
                if l_cons_val < r_cons_val:
                    # loss: l -> r
                    tar_r_class_logit[sdx, ...] = in_l_cons_logit[sdx, ...]
                elif l_cons_val > r_cons_val:
                    # loss: r -> l
                    tar_l_class_logit[sdx, ...] = in_r_cons_logit[sdx, ...]

        # calculate stabilization weight
        stabilization_weight = args.stabilization_scale * ramps.sigmoid_rampup(
            epoch, args.stabilization_rampup)
        if not args.exclude_unlabeled:
            stabilization_weight = (unlabeled_minibatch_size /
                                    minibatch_size) * stabilization_weight

        # stabilization loss for r model
        if args.exclude_unlabeled:
            r_stabilization_loss = stabilization_weight * stabilization_criterion(
                r_cons_logit, tar_l_class_logit) / minibatch_size
        else:
            for idx in range(unlabeled_minibatch_size, minibatch_size):
                tar_l_class_logit[idx, ...] = in_r_cons_logit[idx, ...]
            r_stabilization_loss = stabilization_weight * stabilization_criterion(
                r_cons_logit, tar_l_class_logit) / unlabeled_minibatch_size
        meters.update('r_stable_loss', r_stabilization_loss.data[0])
        r_loss += r_stabilization_loss

        # stabilization loss for l model
        if args.exclude_unlabeled:
            l_stabilization_loss = stabilization_weight * stabilization_criterion(
                l_cons_logit, tar_r_class_logit) / minibatch_size
        else:
            for idx in range(unlabeled_minibatch_size, minibatch_size):
                tar_r_class_logit[idx, ...] = in_l_cons_logit[idx, ...]
            l_stabilization_loss = stabilization_weight * stabilization_criterion(
                l_cons_logit, tar_r_class_logit) / unlabeled_minibatch_size

        meters.update('l_stable_loss', l_stabilization_loss.data[0])
        l_loss += l_stabilization_loss

        if np.isnan(l_loss.data[0]) or np.isnan(r_loss.data[0]):
            LOG.info('Loss value equals to NAN!')
            continue
        assert not (l_loss.data[0] > 1e5), 'L-Loss explosion: {}'.format(
            l_loss.data[0])
        assert not (r_loss.data[0] > 1e5), 'R-Loss explosion: {}'.format(
            r_loss.data[0])
        meters.update('l_loss', l_loss.data[0])
        meters.update('r_loss', r_loss.data[0])

        # calculate prec and error
        l_prec = mt_func.accuracy(l_class_logit.data,
                                  target_var.data,
                                  topk=(1, ))[0]
        r_prec = mt_func.accuracy(r_class_logit.data,
                                  target_var.data,
                                  topk=(1, ))[0]

        meters.update('l_top1', l_prec[0], labeled_minibatch_size)
        meters.update('l_error1', 100. - l_prec[0], labeled_minibatch_size)

        meters.update('r_top1', r_prec[0], labeled_minibatch_size)
        meters.update('r_error1', 100. - r_prec[0], labeled_minibatch_size)

        # update model
        l_optimizer.zero_grad()
        l_loss.backward()
        l_optimizer.step()

        r_optimizer.zero_grad()
        r_loss.backward()
        r_optimizer.step()

        # record
        global_step += 1
        meters.update('batch_time', time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            LOG.info('Epoch: [{0}][{1}/{2}]\t'
                     'Batch-T {meters[batch_time]:.3f}\t'
                     'L-Class {meters[l_class_loss]:.4f}\t'
                     'R-Class {meters[r_class_loss]:.4f}\t'
                     'L-Res {meters[l_res_loss]:.4f}\t'
                     'R-Res {meters[r_res_loss]:.4f}\t'
                     'L-Cons {meters[l_cons_loss]:.4f}\t'
                     'R-Cons {meters[r_cons_loss]:.4f}\n'
                     'L-Stable {meters[l_stable_loss]:.4f}\t'
                     'R-Stable {meters[r_stable_loss]:.4f}\t'
                     'L-Prec@1 {meters[l_top1]:.3f}\t'
                     'R-Prec@1 {meters[r_top1]:.3f}\t'.format(
                         epoch, i, len(train_loader), meters=meters))

            log.record(
                epoch + i / len(train_loader), {
                    'step': global_step,
                    **meters.values(),
                    **meters.averages(),
                    **meters.sums()
                })
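
Note: the consistency and stabilization weights in both examples are ramped up with `ramps.sigmoid_rampup`. A minimal sketch of that helper, assuming the usual Mean Teacher style exponential rampup, is given below; treat it as an illustration rather than the repository's exact code.

# Hedged sketch of the assumed rampup helper.
import numpy as np


def sigmoid_rampup(current, rampup_length):
    # Ramps from 0 to 1 as `current` goes from 0 to `rampup_length`.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

With this definition, consistency_weight grows smoothly from 0 to args.consistency_scale over the first args.consistency_rampup epochs, and stabilization_weight behaves the same way with args.stabilization_scale and args.stabilization_rampup.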