def train(train_loader,
          model,
          optimizer,
          epoch,
          weak_mask=None,
          strong_mask=None):
    class_criterion = nn.BCELoss()
    [class_criterion] = to_cuda_if_available([class_criterion])

    meters = AverageMeterSet()
    meters.update('lr', optimizer.param_groups[0]['lr'])

    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    for i, (batch_input, target) in enumerate(train_loader):
        [batch_input, target] = to_cuda_if_available([batch_input, target])
        LOG.debug(batch_input.mean())

        strong_pred, weak_pred = model(batch_input)
        loss = 0
        if weak_mask is not None:
            # Weak BCE loss
            # Trick to avoid using the unlabeled data: reduce the strong
            # (frame-level) targets to weak (clip-level) ones with a max
            # over the time axis. Todo: find a cleaner way
            target_weak = target.max(-2)[0]
            weak_class_loss = class_criterion(weak_pred[weak_mask],
                                              target_weak[weak_mask])
            if i == 1:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug(weak_class_loss)
            meters.update('Weak loss', weak_class_loss.item())

            loss += weak_class_loss

        if strong_mask is not None:
            # Strong BCE loss
            strong_class_loss = class_criterion(strong_pred[strong_mask],
                                                target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())

            loss += strong_class_loss

        assert not (np.isnan(loss.item())
                    or loss.item() > 1e5), 'Loss explosion: {}'.format(
                        loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_time = time.time() - start

    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
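The weak_mask / strong_mask arguments above are just batch indices. As a minimal sketch, assuming each batch is laid out as [weak | unlabeled | strong] (the layout and sizes here are illustrative assumptions, not taken from this code), they could be built as slices:

batch_size = 24
n_weak, n_strong = 8, 8  # hypothetical split of the batch

# Slices select the weakly / strongly labeled parts of each batch, so
# e.g. weak_pred[weak_mask] keeps only the weakly labeled samples.
weak_mask = slice(0, n_weak)
strong_mask = slice(batch_size - n_strong, batch_size)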
Example #2
    def train(self, get_loss, get_acc, model_file, pretrain_file):
        """ Train with the UDA / MixMatch semi-supervised objective. """
        ssl_mode = self.cfg.uda_mode or self.cfg.mixmatch_mode

        # tensorboardX logging
        writer = None
        if self.cfg.results_dir:
            results_dir = os.path.join('results', self.cfg.results_dir)
            if os.path.isdir(results_dir):
                shutil.rmtree(results_dir)

            writer = SummaryWriter(log_dir=results_dir)

            
        meters = AverageMeterSet()

        self.model.train()

        if self.cfg.model == "custom":
            self.load(model_file, pretrain_file)    # only one of model_file / pretrain_file is actually loaded

        model = self.model.to(self.device)
        ema_model = self.ema_model.to(self.device) if self.ema_model else None

        if self.cfg.data_parallel:                       # Parallel GPU mode
            model = nn.DataParallel(model)
            ema_model = nn.DataParallel(ema_model) if ema_model else None

        global_step = 0
        loss_sum = 0.
        max_acc = [0., 0, 0., 0.]   # acc, step, val_loss, train_loss
        no_improvement = 0

        unsup_batch_size = None

        # The progress bar iterates over the unsup data in SSL mode (sup_iter
        # is repeated via next(), see the sketch after this method) and over
        # the sup data otherwise.
        iter_bar = tqdm(self.unsup_iter, total=self.cfg.total_steps, disable=self.cfg.hide_tqdm) if ssl_mode \
              else tqdm(self.sup_iter, total=self.cfg.total_steps, disable=self.cfg.hide_tqdm)

        start = time.time()

        for i, batch in enumerate(iter_bar):
            # Device assignment
            if ssl_mode:
                sup_batch = [t.to(self.device) for t in next(self.sup_iter)]
                unsup_batch = [t.to(self.device) for t in batch]

                unsup_batch_size = unsup_batch_size or unsup_batch[0].shape[0]

                # Drop incomplete trailing unsup batches so the batch size stays constant
                if unsup_batch[0].shape[0] != unsup_batch_size:
                    continue
            else:
                sup_batch = [t.to(self.device) for t in batch]
                unsup_batch = None

            # update
            self.optimizer.zero_grad()
            final_loss, sup_loss, unsup_loss, weighted_unsup_loss = get_loss(model, sup_batch, unsup_batch, global_step)

            if self.cfg.no_sup_loss:
                final_loss = unsup_loss
            elif self.cfg.no_unsup_loss:
                final_loss = sup_loss

            meters.update('train_loss', final_loss.item())
            meters.update('sup_loss', sup_loss.item())
            meters.update('unsup_loss', unsup_loss.item())
            meters.update('w_unsup_loss', weighted_unsup_loss.item())
            meters.update('lr', self.optimizer.get_lr()[0])

            final_loss.backward()
            self.optimizer.step()

            if self.ema_optimizer:
                self.ema_optimizer.step()

            # bookkeeping and progress-bar update
            global_step += 1
            loss_sum += final_loss.item()
            if not self.cfg.hide_tqdm:
                if ssl_mode:
                    iter_bar.set_description('final=%5.3f unsup=%5.3f sup=%5.3f'\
                            % (final_loss.item(), unsup_loss.item(), sup_loss.item()))
                else:
                    iter_bar.set_description('loss=%5.3f' % (final_loss.item()))

            if global_step % self.cfg.save_steps == 0:
                self.save(global_step)

            if get_acc and global_step % self.cfg.check_steps == 0 and global_step > self.cfg.check_after:
                if self.cfg.mixmatch_mode:
                    # assumption: self.eval returns (accuracy, avg_loss), mirroring validate()
                    total_accuracy, avg_val_loss = self.eval(get_acc, None, ema_model)
                else:
                    total_accuracy, avg_val_loss = self.validate()

                # logging
                if writer:
                    writer.add_scalars('data/eval_acc', {'eval_acc': total_accuracy}, global_step)
                    writer.add_scalars('data/eval_loss', {'eval_loss': avg_val_loss}, global_step)
                    writer.add_scalars('data/train_loss', {'train_loss': meters['train_loss'].avg}, global_step)
                    writer.add_scalars('data/lr', {'lr': meters['lr'].avg}, global_step)
                    if not self.cfg.no_unsup_loss:
                        writer.add_scalars('data/sup_loss', {'sup_loss': meters['sup_loss'].avg}, global_step)
                        writer.add_scalars('data/unsup_loss', {'unsup_loss': meters['unsup_loss'].avg}, global_step)
                        writer.add_scalars('data/w_unsup_loss', {'w_unsup_loss': meters['w_unsup_loss'].avg}, global_step)

                meters.reset()

                if max_acc[0] < total_accuracy:
                    self.save(global_step)
                    max_acc = total_accuracy, global_step, avg_val_loss, final_loss.item()
                    no_improvement = 0
                else:
                    no_improvement += 1

                print("  Top 1 Accuracy: {0:.4f}".format(total_accuracy))
                print("  Validation Loss: {0:.4f}".format(avg_val_loss))
                print("  Train Loss: {0:.4f}".format(final_loss.item()))
                if ssl_mode:
                    print("  Sup Loss: {0:.4f}".format(sup_loss.item()))
                    print("  Unsup Loss: {0:.4f}".format(unsup_loss.item()))
                print("  Learning rate: {0:.7f}".format(self.optimizer.get_lr()[0]))

                print(
                    'Max Accuracy : %5.3f Best Val Loss : %5.3f Best Train Loss : %5.4f Max global_steps : %d Cur global_steps : %d' 
                    %(max_acc[0], max_acc[2], max_acc[3], max_acc[1], global_step), end='\n\n'
                )
                
                if no_improvement == self.cfg.early_stopping:
                    print("Early stopped")
                    total_time = time.time() - start
                    print('Total Training Time: %d' %(total_time), end='\n')
                    break


            if self.cfg.total_steps and self.cfg.total_steps < global_step:
                print('The total steps have been reached')
                total_time = time.time() - start
                print('Total Training Time: %d' %(total_time), end='\n') 
                if get_acc:
                    if self.cfg.mixmatch_mode:
                        # assumption: self.eval returns (accuracy, avg_loss), mirroring validate()
                        total_accuracy, avg_val_loss = self.eval(get_acc, None, ema_model)
                    else:
                        total_accuracy, avg_val_loss = self.validate()
                    if max_acc[0] < total_accuracy:
                        max_acc = total_accuracy, global_step, avg_val_loss, final_loss.item()
                    print("  Top 1 Accuracy: {0:.4f}".format(total_accuracy))
                    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
                    print("  Train Loss: {0:.2f}".format(final_loss.item()))
                    print('Max Accuracy : %5.3f Best Val Loss :  %5.3f Best Train Loss :  %5.3f Max global_steps : %d Cur global_steps : %d' %(max_acc[0], max_acc[2], max_acc[3], max_acc[1], global_step), end='\n\n')
                self.save(global_step)
                if writer:
                    writer.close()
                return global_step

        if writer:
            writer.close()
        return global_step
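In SSL mode the loop above calls next(self.sup_iter) for every unsup batch, so sup_iter has to be an iterator that never exhausts. A minimal sketch of such a repeating iterator, assuming a standard torch DataLoader (repeat_dataloader is an illustrative name, not this repo's helper):

def repeat_dataloader(loader):
    """Yield batches forever, restarting (and reshuffling) the DataLoader each pass."""
    while True:
        for batch in loader:
            yield batch

# e.g. self.sup_iter = repeat_dataloader(sup_loader)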
Example #3
def train(cfg,
          train_loader,
          model,
          optimizer,
          epoch,
          ema_model=None,
          weak_mask=None,
          strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
    Should return 3 values: teacher input, student input, labels
    :param model: torch.Module, model to be trained, should return a weak and strong prediction
    :param optimizer: torch.Module, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.Module, student model, should return a weak and strong prediction
    :param weak_mask: mask the batch to get only the weak labeled data (used to calculate the loss)
    :param strong_mask: mask the batch to get only the strong labeled data (used to calcultate the loss)
    """
    class_criterion = nn.BCELoss()
    consistency_criterion_strong = nn.MSELoss()
    lds_criterion = LDSLoss(xi=cfg.vat_xi,
                            eps=cfg.vat_eps,
                            n_power_iter=cfg.vat_n_power_iter)
    [class_criterion, consistency_criterion_strong,
     lds_criterion] = to_cuda_if_available(
         [class_criterion, consistency_criterion_strong, lds_criterion])

    meters = AverageMeterSet()

    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2
    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input,
         target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug(batch_input.mean())
        # Outputs
        strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()

        strong_pred, weak_pred = model(batch_input)
        loss = None
        # Weak BCE loss
        # Take the max over axis -2 (assumed to be time)
        if len(target.shape) > 2:
            target_weak = target.max(-2)[0]
        else:
            target_weak = target

        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask],
                                              target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask],
                                             target_weak[weak_mask])

            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(
                    target_weak[weak_mask]))
                LOG.debug(weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', weak_class_loss.item())

            meters.update('Weak EMA loss', ema_class_loss.item())

            loss = weak_class_loss

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask],
                                                target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())

            strong_ema_class_loss = class_criterion(
                strong_pred_ema[strong_mask], target[strong_mask])
            meters.update('Strong EMA loss', strong_ema_class_loss.item())
            if loss is not None:
                loss += strong_class_loss
            else:
                loss = strong_class_loss

        # Teacher-student consistency cost
        if ema_model is not None:

            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Consistency on the strong predictions (computed over the whole batch)
            consistency_loss_strong = consistency_cost * consistency_criterion_strong(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            # Consistency on the weak predictions (computed over the whole batch)
            consistency_loss_weak = consistency_cost * consistency_criterion_strong(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        # LDS (VAT) loss
        if cfg.vat_enabled:
            lds_loss = cfg.vat_coeff * lds_criterion(model, batch_input,
                                                     weak_pred)
            # lds_loss already includes cfg.vat_coeff, so log it directly
            LOG.info('loss: {:.3f}, lds loss: {:.3f}'.format(
                loss.item(), lds_loss.item()))
            loss += lds_loss
        else:
            if i % 25 == 0:
                LOG.info('loss: {:.3f}'.format(loss.item()))

        assert not (np.isnan(loss.item())
                    or loss.item() > 1e5), 'Loss explosion: {}'.format(
                        loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start

    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))
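For reference, ramps.sigmoid_rampup used above is typically the exponential ramp-up from the Mean Teacher paper; the exact module is not shown here, so treat this as a sketch under that assumption:

import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from 0 to 1 over rampup_length steps."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))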
def train(train_loader,
          model,
          optimizer,
          epoch,
          ema_model=None,
          weak_mask=None,
          strong_mask=None):
    """ One epoch of a Mean Teacher model
    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
    Should return 3 values: teacher input, student input, labels
    :param model: torch.Module, model to be trained, should return a weak and strong prediction
    :param optimizer: torch.Module, optimizer used to train the model
    :param epoch: int, the current epoch of training
    :param ema_model: torch.Module, student model, should return a weak and strong prediction
    :param weak_mask: mask the batch to get only the weak labeled data (used to calculate the loss)
    :param strong_mask: mask the batch to get only the strong labeled data (used to calcultate the loss)
    """
    class_criterion = nn.BCELoss()

    ##################################################
    class_criterion1 = nn.BCELoss(reduction='none')
    ##################################################

    consistency_criterion = nn.MSELoss()

    [class_criterion, class_criterion1,
     consistency_criterion] = to_cuda_if_available(
         [class_criterion, class_criterion1, consistency_criterion])

    meters = AverageMeterSet()

    LOG.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    rampup_length = len(train_loader) * cfg.n_epoch // 2

    print("Train\n")
    count = 0
    check_cus_weak = 0
    difficulty_loss = 0  # flag so the difficulty-loss message is logged only once
    loss_w = 1
    LOG.info("loss parameter: {}".format(loss_w))
    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * len(train_loader) + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input,
         target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug("batch_input:{}".format(batch_input.mean()))


        # Outputs
        ##################################################
        # strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema, weak_pred_ema, sof_ema = ema_model(ema_batch_input)
        sof_ema = sof_ema.detach()
        ##################################################

        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()

        ##################################################
        # strong_pred, weak_pred = model(batch_input)
        strong_pred, weak_pred, sof = model(batch_input)
        ##################################################

        ##################################################
        # custom_ema_loss = Custom_BCE_Loss(ema_batch_input, class_criterion1)

        if difficulty_loss == 0:
            LOG.info("############### Deffine Difficulty Loss ###############")
            difficulty_loss = 1
        custom_ema_loss = Custom_BCE_Loss_difficulty(ema_batch_input,
                                                     class_criterion1,
                                                     paramater=loss_w)
        custom_ema_loss.initialize(strong_pred_ema, sof_ema)

        # custom_loss = Custom_BCE_Loss(batch_input, class_criterion1)
        custom_loss = Custom_BCE_Loss_difficulty(batch_input,
                                                 class_criterion1,
                                                 paramater=loss_w)
        custom_loss.initialize(strong_pred, sof)
        ##################################################

        # print(strong_pred.shape)
        # print(strong_pred)
        # print(weak_pred.shape)
        # print(weak_pred)
        # exit()

        loss = None
        # Weak BCE Loss
        # Take the max over the time axis to get clip-level (weak) targets
        target_weak = target.max(-2)[0]
        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask],
                                              target_weak[weak_mask])
            ema_class_loss = class_criterion(weak_pred_ema[weak_mask],
                                             target_weak[weak_mask])

            print("normal_weak:", weak_class_loss)

            ##################################################
            custom_weak_class_loss = custom_loss.weak(target_weak, weak_mask)
            custom_ema_class_loss = custom_ema_loss.weak(
                target_weak, weak_mask)
            print("custom_weak:", custom_weak_class_loss)
            ##################################################

            count += 1
            # accumulate as a float so no autograd graphs are kept alive
            check_cus_weak += custom_weak_class_loss.item()

            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(
                    target_weak[weak_mask]))
                LOG.debug(custom_weak_class_loss)  ###
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss',
                          custom_weak_class_loss.item())  ###
            meters.update('Weak EMA loss', custom_ema_class_loss.item())  ###

            # loss = weak_class_loss
            loss = custom_weak_class_loss


        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask],
                                                target[strong_mask])
            # meters.update('Strong loss', strong_class_loss.item())

            strong_ema_class_loss = class_criterion(
                strong_pred_ema[strong_mask], target[strong_mask])
            # meters.update('Strong EMA loss', strong_ema_class_loss.item())

            print("normal_strong:", strong_class_loss)

            ##################################################
            custom_strong_class_loss = custom_loss.strong(target, strong_mask)
            meters.update('Strong loss', custom_strong_class_loss.item())

            custom_strong_ema_class_loss = custom_ema_loss.strong(
                target, strong_mask)
            meters.update('Strong EMA loss',
                          custom_strong_ema_class_loss.item())
            print("custom_strong:", custom_strong_class_loss)
            ##################################################

            if loss is not None:
                # loss += strong_class_loss
                loss += custom_strong_class_loss
            else:
                # loss = strong_class_loss
                loss = custom_strong_class_loss

        # print("check_weak:", class_criterion1(weak_pred[weak_mask], target_weak[weak_mask]).mean())
        # print("check_strong:", class_criterion1(strong_pred[strong_mask], target[strong_mask]).mean())
        # print("\n")

        # exit()

        # Teacher-student consistency cost
        if ema_model is not None:

            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take consistency about strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            # Take consistency about weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        assert not (np.isnan(loss.item())
                    or loss.item() > 1e5), 'Loss explosion: {}'.format(
                        loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start

    LOG.info('Epoch: {}\t'
             'Time {:.2f}\t'
             '{meters}'.format(epoch, epoch_time, meters=meters))

    if count:
        print("\ncheck_cus_weak:\n", check_cus_weak / count)
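update_ema_variables is not defined in this file; a common Mean Teacher implementation that matches the update_ema_variables(model, ema_model, 0.999, global_step) calls above looks like the sketch below. Treat it as an assumption about the helper, not this repo's exact code:

def update_ema_variables(model, ema_model, alpha, global_step):
    # Ramp the decay up from 0 toward alpha so the teacher tracks the
    # student closely during the first steps.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)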