def inference(model, itr, pop_idx, arch, val_loader, use_gpu, args):
    """Measure top-1/top-5 precision of one candidate sub-network.

    The model is put in eval mode and run with ``arch`` over every batch of
    ``val_loader`` with gradients disabled. Per-batch precisions are averaged,
    weighted by batch size. When ``args.distributed`` is set, each batch's
    precisions are first passed through ``reduce_tensor`` with
    ``args.world_size``. Local rank 0 logs a one-line summary that also
    includes the architecture's FLOPs.

    Args:
        model: Supernet called as ``model(img, arch)``.
        itr: Current search iteration (logging only).
        pop_idx: Zero-based index of this candidate in the population
            (logging only).
        arch: Architecture encoding forwarded to the model and ``get_flops``.
        val_loader: Iterable of ``(img, label)`` batches.
        use_gpu: Move each batch to CUDA before the forward pass.
        args: Uses ``distributed``, ``world_size``, ``flop_table``,
            ``local_rank``, ``total_search_iters`` and ``pop_size``.

    Returns:
        ``(avg_prec1, avg_prec5, flops)`` where ``flops`` is
        ``get_flops(...) / 1e6``.
    """
    model.eval()
    top1_meter, top5_meter = AverageMeter(), AverageMeter()

    with torch.no_grad():
        for images, targets in val_loader:
            if use_gpu:
                images, targets = images.cuda(), targets.cuda()

            logits = model(images, arch)
            top1, top5 = accuracy(logits, targets, topk=(1, 5))

            if args.distributed:
                top1 = reduce_tensor(top1, args.world_size)
                top5 = reduce_tensor(top5, args.world_size)

            n_samples = images.size(0)
            top1_meter.update(top1.item(), n_samples)
            top5_meter.update(top5.item(), n_samples)

        mflops = get_flops(arch, args.flop_table) / 1e6
        if args.local_rank == 0:
            summary = ('Iter: [{}/{}][{}/{}]\t'
                       'Arch: {}\t'
                       'FLOPs: {:.2f}M\t'
                       'Prec@1: {:.2f}%\t'
                       'Prec@5: {:.2f}%').format(
                           itr, args.total_search_iters,
                           pop_idx + 1, args.pop_size,
                           arch, mflops,
                           top1_meter.avg, top5_meter.avg)
            logging.info(summary)

    return top1_meter.avg, top5_meter.avg, mflops
示例#2
0
 def sync(self):
     """Synchronize this meter's statistics across all distributed ranks.

     ``val`` is reduced with ``reduce_tensor(val, world_size)``. ``sum`` and
     ``count`` are reduced with a divisor of 1. ``avg`` is then recomputed from
     the reduced ``sum`` and ``count``. Requires an initialized
     ``torch.distributed`` process group and a CUDA device.
     """
     # The unused ``rank = dist.get_rank()`` local was dropped; only the world
     # size is needed.
     world_size = dist.get_world_size()
     val = torch.tensor(self.val).cuda()
     sum_v = torch.tensor(self.sum).cuda()
     count = torch.tensor(self.count).cuda()
     # NOTE(review): presumably reduce_tensor all-reduces (sums) and divides by
     # its second argument. Under that assumption val becomes the cross-rank
     # mean, while sum/count become cross-rank totals. Confirm against
     # reduce_tensor.
     self.val = reduce_tensor(val, world_size).item()
     self.sum = reduce_tensor(sum_v, 1).item()
     self.count = reduce_tensor(count, 1).item()
     # max(1, ...) avoids division by zero when nothing was recorded.
     self.avg = self.sum / max(1, self.count)
示例#3
0
    def _forward(self, x, logpx=None):
        """Normalize ``x`` with moving batch statistics.

        Computes ``y = (x - mean) / sqrt(var + eps)``, followed by an optional
        affine transform ``y * exp(weight) + bias``. When ``logpx`` is given,
        also returns ``logpx`` minus the summed ``self._logdetgrad(x, var)``
        term, so the layer can be used where a log-density is tracked.

        In training mode, batch statistics are computed and the running
        estimates are updated in place (``running_mean``, ``running_var``,
        ``step``). In eval mode only the running estimates are used.

        Args:
            x: Input tensor. Channels are taken from the last dimension
                (``x.size(-1)``).
            logpx: Optional log-density term to update.

        Returns:
            ``y`` if ``logpx`` is None, otherwise ``(y, updated_logpx)``.
        """
        num_channels = x.size(-1)
        # Start from detached copies of the running estimates. Unless bn_lag > 0
        # overrides them below, these are the statistics used for
        # normalization, even in training mode.
        used_mean = self.running_mean.clone().detach()
        used_var = self.running_var.clone().detach()

        if self.training:
            # compute batch statistics
            # For 2-D x of shape (N, C) this yields (C, N). NOTE(review): for
            # higher-rank x, transpose(0, 1) followed by reshape(C, -1) does not
            # group values by the last dimension. Confirm callers only pass
            # inputs where this is intended.
            x_t = x.transpose(0, 1).reshape(num_channels, -1)
            batch_mean = torch.mean(x_t, dim=1)

            if self.sync:
                # Cross-process statistics: reduce E[x] and E[x^2], then derive
                # the (biased) variance. The non-sync branch below uses
                # torch.var, whose default is the unbiased estimator, so the two
                # branches differ slightly. Presumably reduce_tensor averages
                # across ranks; confirm.
                batch_ex2 = torch.mean(x_t**2, dim=1)
                batch_mean = reduce_tensor(batch_mean)
                batch_ex2 = reduce_tensor(batch_ex2)
                batch_var = batch_ex2 - batch_mean**2
            else:
                batch_var = torch.var(x_t, dim=1)

            # moving average
            # used = bn_lag * batch + (1 - bn_lag) * running, then divided by
            # (1 - bn_lag ** (step + 1)) as a step-dependent correction. Only
            # the batch statistics carry gradients.
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (
                    batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
                used_var = batch_var - (1 - self.bn_lag) * (batch_var -
                                                            used_var.detach())
                used_var /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            # In-place exponential moving average:
            # running <- (1 - decay) * running + decay * batch (no autograd via .data).
            self.running_mean -= self.decay * (self.running_mean -
                                               batch_mean.data)
            self.running_var -= self.decay * (self.running_var -
                                              batch_var.data)
            self.step += 1

        # perform normalization
        # Broadcast per-channel statistics to x's shape using self.shape.
        used_mean = used_mean.view(*self.shape).expand_as(x)
        used_var = used_var.view(*self.shape).expand_as(x)

        # exp(-0.5 * log(v)) == 1 / sqrt(v)
        y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

        if self.affine:
            # weight is applied through exp(), i.e. it acts as a log-scale.
            weight = self.weight.view(*self.shape).expand_as(x)
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y * torch.exp(weight) + bias

        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad(x, used_var).sum(-1,
                                                                keepdim=True)
示例#4
0
def ens_validate(val_loader, model, criterion, args, log, num_mc_samples=20, suffix=''):
    """Evaluate a model as an ensemble of ``num_mc_samples`` forward passes.

    For every batch the model is called ``num_mc_samples`` times and the
    softmax outputs are averaged incrementally. For ensemble size ``ens + 1``,
    row ``ens`` of ``rets`` accumulates values weighted by batch size and then
    divided by this process's total sample count:

    * column 0: ``ens`` itself;
    * columns 1-3: loss / prec@1 / prec@5 of the latest single sample;
    * columns 5-7: loss / prec@1 / prec@5 of the running ensemble average;
    * column 4 is never written;
    * column 8 of the last row receives the ensemble ECE when ``suffix == ''``.

    A per-example ``mis`` score is also computed: the entropy of the averaged
    prediction minus the average per-sample entropy (mutual information). It
    is saved to ``<save_path>/mis<suffix>.npy`` when ``args.gpu == 0``.

    Args:
        val_loader: Iterable of ``(input, target)`` batches.
        model: Model whose outputs presumably differ between calls (e.g. sampled
            weights or dropout). Otherwise all samples are identical.
        criterion: Loss applied to single-sample logits and to the log of the
            ensemble probabilities.
        args: Uses ``gpu``, ``distributed`` and ``save_path``.
        log: Not used by this function.
        num_mc_samples: Number of forward passes per batch.
        suffix: Appended to output filenames. The ECE computation and the
            gathering of confidences/predictions/targets only happen when it
            is ``''``.

    Returns:
        A NumPy array of shape ``(num_mc_samples, 9)`` as described above.
    """
    model.eval()

    # NOTE(review): _ECELoss is presumably an expected-calibration-error helper
    # that also writes a plot to the given PDF path; confirm.
    ece_func = _ECELoss().cuda(args.gpu)
    with torch.no_grad():
        targets = []
        # Per-batch running means: mis[i] = mean per-sample entropy,
        # preds[i] = mean softmax probabilities.
        mis = [0 for _ in range(len(val_loader))]
        preds = [0 for _ in range(len(val_loader))]
        rets = torch.zeros(num_mc_samples, 9).cuda(args.gpu)
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            targets.append(target)
            for ens in range(num_mc_samples):
                output = model(input)

                # Metrics of this single forward pass.
                one_loss = criterion(output, target)
                one_prec1, one_prec5 = accuracy(output, target, topk=(1, 5))

                # Incremental mean over the first ens + 1 samples.
                mis[i] = (mis[i] * ens + (-output.softmax(-1) *
                    output.log_softmax(-1)).sum(1)) / (ens + 1)
                preds[i] = (preds[i] * ens + output.softmax(-1)) / (ens + 1)

                # Metrics of the ensemble of ens + 1 samples. The criterion
                # receives log-probabilities. If it is CrossEntropyLoss,
                # log_softmax leaves them unchanged, so this is the ensemble
                # NLL (assumption; confirm criterion type).
                loss = criterion(preds[i].log(), target)
                prec1, prec5 = accuracy(preds[i], target, topk=(1, 5))

                # Sample-weighted accumulation; normalized after the loop.
                rets[ens, 0] += ens*target.size(0)
                rets[ens, 1] += one_loss.item()*target.size(0)
                rets[ens, 2] += one_prec1.item()*target.size(0)
                rets[ens, 3] += one_prec5.item()*target.size(0)
                rets[ens, 5] += loss.item()*target.size(0)
                rets[ens, 6] += prec1.item()*target.size(0)
                rets[ens, 7] += prec5.item()*target.size(0)

        # Probabilities of the full ensemble for every example on this process.
        preds = torch.cat(preds, 0)

        # to sync
        confidences, predictions = torch.max(preds, 1)
        targets = torch.cat(targets, 0)
        # Mutual information: H[mean p] - mean H[p].
        mis = (- preds * preds.log()).sum(1) - torch.cat(mis, 0)
        rets /= targets.size(0)

        if args.distributed:
            if suffix == '':
                confidences = dist_collect(confidences)
                predictions = dist_collect(predictions)
                targets = dist_collect(targets)
            mis = dist_collect(mis)
            # NOTE(review): reduce_tensor is given `args` here (a world size
            # elsewhere); presumably it averages per-rank rets. Confirm.
            rets = reduce_tensor(rets.data, args)
        rets = rets.data.cpu().numpy()
        if suffix == '':
            ens_ece = ece_func(confidences, predictions, targets,
                os.path.join(args.save_path, 'ens_cal{}.pdf'.format(suffix)))
            # ECE of the full ensemble goes into the otherwise-unused last cell.
            rets[-1, -1] = ens_ece

    if args.gpu == 0:
        np.save(os.path.join(args.save_path, 'mis{}.npy'.format(suffix)),
            mis.data.cpu().numpy())
    return rets
示例#5
0
def validate_single_class(config, data_loader, model):
    """Evaluate ``model`` with cross-entropy loss and top-1/top-5 accuracy.

    Each batch's loss and accuracies are passed through ``reduce_tensor``
    (presumably a cross-process average; confirm) before being accumulated,
    weighted by batch size. Progress, including peak CUDA memory, is logged
    every ``config.PRINT_FREQ`` batches, and a final summary is logged at the
    end.

    Args:
        config: Object exposing ``PRINT_FREQ``.
        data_loader: Sized iterable of ``(images, target)`` batches.
        model: Classifier returning logits.

    Returns:
        ``(acc1_avg, acc5_avg, loss_avg)``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    # Evaluation needs no gradients. Without no_grad every forward pass records
    # an autograd graph, which wastes memory/time and inflates the logged peak
    # memory.
    with torch.no_grad():
        for idx, (images, target) in enumerate(data_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)

            # measure accuracy and record loss
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            acc1 = reduce_tensor(acc1)
            acc5 = reduce_tensor(acc5)
            loss = reduce_tensor(loss)

            loss_meter.update(loss.item(), target.size(0))
            acc1_meter.update(acc1.item(), target.size(0))
            acc5_meter.update(acc5.item(), target.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % config.PRINT_FREQ == 0:
                memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
                logger.info(
                    f'Test: [{idx}/{len(data_loader)}]\t'
                    f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                    f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                    f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                    f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
示例#6
0
def train_epoch(model, epoch, optim_tools, train_loader, use_gpu):
    """Train the supernet for one epoch, drawing a fresh sub-network per batch.

    Each batch obtains ``arch`` from ``uniform_constraint_sampling`` (which
    reads the module-level ``args``), runs forward/backward with it, and then
    steps both the optimizer and the LR scheduler.

    Loss recording depends on the mode. Without ``args.distributed``, the loss
    is recorded every batch. In distributed mode it is reduced across ranks
    and recorded only on display batches (the first, every
    ``args.disp_freq``-th, and the last). Local rank 0 logs progress on those
    batches.

    Args:
        model: Supernet called as ``model(img, arch)``.
        epoch: Current epoch (logging only).
        optim_tools: ``(optimizer, criterion, scheduler)`` tuple.
        train_loader: Sized iterable of ``(img, label)`` batches.
        use_gpu: Move batches to CUDA and synchronize before timing.
    """
    model.train()
    optimizer, criterion, scheduler = optim_tools

    losses = AverageMeter()
    train_time = AverageMeter()
    data_time = AverageMeter()

    tick = time.time()
    for batch_idx, (img, label) in enumerate(train_loader):
        # Time spent waiting on the loader since the previous step finished.
        data_time.update(time.time() - tick)
        if use_gpu:
            img, label = img.cuda(), label.cuda()

        arch, _flops = uniform_constraint_sampling(
            sum(args.num_layer_list),
            args.num_block_type,
            args.flop_table,
            args.local_rank,
        )

        loss = criterion(model(img, arch), label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if not args.distributed:
            losses.update(loss.item(), img.size(0))

        if use_gpu:
            torch.cuda.synchronize()
        train_time.update(time.time() - tick)

        is_display_batch = (batch_idx == 0
                            or (batch_idx + 1) % args.disp_freq == 0
                            or batch_idx + 1 == len(train_loader))
        if is_display_batch:
            if args.distributed:
                # Collective call: every rank reaches this on the same batches.
                reduced = reduce_tensor(loss.detach(), args.world_size)
                losses.update(reduced.item(), img.size(0))

            if args.local_rank == 0:
                current_lr = scheduler.get_lr()[0]
                message = ('Epoch: [{}/{}][{}/{}]\t'
                           'LR: {:.2e}\t'
                           'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                           'Train time: {train_time.val:.4f}s ({train_time.avg:.4f}s)\t'
                           'Load data time: {data_time.val:.4f}s ({data_time.avg:.4f}s)').format(
                               epoch, args.total_epochs,
                               batch_idx + 1, len(train_loader),
                               current_lr,
                               loss=losses,
                               train_time=train_time,
                               data_time=data_time)
                logging.info(message)
        tick = time.time()
示例#7
0
def validate(model, loader, use_gpu):
    """Compute batch-size-weighted top-1/top-5 precision of ``model`` on ``loader``.

    Runs without gradients. When the module-level ``args.distributed`` is set,
    each batch's precisions are passed through ``reduce_tensor`` with
    ``args.world_size``. Local rank 0 logs running averages and per-batch time
    on the first batch, every ``args.disp_freq``-th batch, and the last batch.

    Args:
        model: Classifier called as ``model(img)``.
        loader: Sized iterable of ``(img, label)`` batches.
        use_gpu: Move batches to CUDA and synchronize before timing.

    Returns:
        ``(avg_prec1, avg_prec5)``.
    """
    model.eval()
    top1, top5, batch_timer = AverageMeter(), AverageMeter(), AverageMeter()

    tick = time.time()
    with torch.no_grad():
        for batch_idx, (img, label) in enumerate(loader):
            if use_gpu:
                img, label = img.cuda(), label.cuda()

            p1, p5 = accuracy(model(img), label, topk=(1, 5))
            if args.distributed:
                p1 = reduce_tensor(p1, args.world_size)
                p5 = reduce_tensor(p5, args.world_size)

            n_samples = img.size(0)
            top1.update(p1.item(), n_samples)
            top5.update(p5.item(), n_samples)

            if use_gpu:
                torch.cuda.synchronize()
            batch_timer.update(time.time() - tick)

            seen = batch_idx + 1
            if args.local_rank == 0 and (
                    batch_idx == 0 or seen % args.disp_freq == 0 or seen == len(loader)):
                logging.info(
                    ('Iter: [{}/{}]\t'
                     'Val time: {:.4f}s\t'
                     'Prec@1: {:.2f}%\t'
                     'Prec@5: {:.2f}%').format(seen, len(loader), batch_timer.avg,
                                               top1.avg, top5.avg))
            tick = time.time()

    return top1.avg, top5.avg
示例#8
0
    def sync(self):
        """Merge the pending local sum/count into the running history average.

        The pending ``_sum`` and ``_count`` are packed into one float32 CUDA
        tensor and passed through ``reduce_tensor(..., 1)``. The resulting
        average is blended into ``_history_avg``, weighted by how many samples
        the history and the new chunk each represent. The pending accumulators
        are then cleared and the cached ``_avg`` is invalidated.
        """
        packed = torch.tensor([self._sum, self._count],
                              dtype=torch.float32).cuda()
        total_sum, total_count = reduce_tensor(packed, 1).tolist()
        chunk_avg = total_sum / max(1, total_count)

        prior_count = self._history_count
        # Fraction of the combined sample count owned by the existing history.
        keep = prior_count / max(1, prior_count + total_count)
        self._history_avg = keep * self._history_avg + (1.0 - keep) * chunk_avg
        self._history_count = prior_count + total_count

        # Reset pending accumulators; the cached average is now stale.
        self._sum = 0
        self._count = 0
        self._avg = None
示例#9
0
    def _step(self, inputs, input_sizes, targets):
        """Run one forward pass and, if the loss is valid, one gradient update.

        Called by ``train()``; not meant to be called directly.

        Parameters
        ----------
        inputs:
            Batch fed to ``self.model`` along with ``input_sizes``. Its first
            dimension is used to average the loss over the minibatch.
        input_sizes:
            Per-example sizes forwarded to ``self.model``.
        targets:
            Targets, cast to ``long`` before ``self.criterion`` is applied.

        Returns
        -------
        tuple
            ``(output, loss_value)``: the model output and the scalar loss,
            reduced across ranks when distributed. ``loss_value`` is 0 when
            ``check_loss`` rejects the loss and the update is skipped.
        """
        output = self.model(inputs, input_sizes)
        # Criterion value averaged over the minibatch.
        loss = self.criterion(output, targets.long()) / inputs.size(0)

        if not self.distributed:
            loss_value = loss.item()
        else:
            loss = loss.to(self.device)
            loss_value = reduce_tensor(loss, self.world_size).item()

        # Validity is judged on the (possibly cross-rank) reported value.
        valid_loss, error = check_loss(loss, loss_value)
        if not valid_loss:
            print(error)
            print('Skipping grad update')
            return output, 0

        self.optimizer.zero_grad()
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                       self.max_norm)
        self.optimizer.step()

        return output, loss_value
示例#10
0
    def step(self, loss):
        """Back-propagate ``loss`` and update parameters when the loss is valid.

        The reported value is reduced across ranks via ``reduce_tensor`` when
        ``self.distributed`` is set. If ``check_loss`` rejects it, the error is
        printed and no update happens. Otherwise, gradients are scaled through
        ``amp``, clipped to ``self.max_norm``, and the optimizer is stepped.

        Returns:
            True if the update was applied (and the reported loss was added to
            ``self.avg_loss``), False if it was skipped.
        """
        if not self.distributed:
            reported = loss.item()
        else:
            loss = loss.to(self.device)
            reported = reduce_tensor(loss, self.world_size).item()

        is_valid, error = check_loss(loss, reported)
        if not is_valid:
            print(error)
            print('Skipping grad update')
            return False

        self.optimizer.zero_grad()
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.max_norm)
        self.optimizer.step()

        self.avg_loss += reported
        return True
示例#11
0
def predict(model, path):
    """Evaluate ``model`` on the pickled dev split and write the results to disk.

    Loads ``(x_dev, y_dev)`` from ``data/dev.pkl`` and builds features with
    ``convert_examples_to_features``. The model is then run over them in
    sequential batches without gradient tracking. The mean per-batch loss and
    the accuracy are logged and written to
    ``<args.output_dir>/pre_results.txt``.

    Relies on module-level ``args``, ``device``, ``tokenizer``,
    ``label_list``, ``num_labels``, ``predict_processor``, ``accuracy``,
    ``reduce_tensor`` and ``logger``.

    Args:
        model: Classifier called as ``model(input_ids, segment_ids, input_mask)``
            that returns logits.
        path: Forwarded to ``predict_processor.get_test_examples``.

    Returns:
        ``(pre_loss, pre_accuracy)``: the loss averaged over batches, and the
        sum of per-batch ``accuracy`` values divided by the number of examples
        (presumably ``accuracy`` returns a count of correct predictions;
        confirm).
    """
    # NOTE(review): pickle.load can execute arbitrary code for a crafted file;
    # only load dev.pkl from a trusted source.
    with open(os.path.join(Path('data'), "dev.pkl"), "rb") as fin:
        x_dev, y_dev = pickle.load(fin)
    # Pass the labels as the third argument. Previously x_dev was passed twice,
    # so the loaded y_dev was never used.
    dev_examples = predict_processor.get_test_examples(path, x_dev, y_dev, size=-1)
    test_features = convert_examples_to_features(
        dev_examples, label_list, args.max_seq_length, tokenizer)

    logger.info("***** Running prediction *****")
    logger.info("  Num examples = %d", len(dev_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in test_features], dtype=torch.long)
    test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    # Run prediction for full data
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)

    # Loop-invariant: build the loss module once rather than once per batch.
    loss_fct = CrossEntropyLoss().to(device)

    model.eval()
    pre_loss, pre_accuracy = 0, 0
    nb_pre_steps, nb_pre_examples = 0, 0
    for batch in tqdm(test_dataloader, desc="Prediction Iteration"):
        input_ids, input_mask, segment_ids, label_ids = (t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)

        flat_logits = logits.view(-1, num_labels)
        # view(-1) instead of squeeze(): squeeze() collapses a one-example
        # batch to a 0-d tensor, which CrossEntropyLoss rejects as a target.
        flat_labels = label_ids.view(-1)
        tmp_pre_loss = loss_fct(flat_logits, flat_labels)

        tmp_pre_accuracy = accuracy(flat_logits.detach().cpu().numpy(),
                                    flat_labels.detach().cpu().numpy())

        if args.local_rank != -1:
            tmp_pre_loss = reduce_tensor(tmp_pre_loss)
            tmp_pre_accuracy = reduce_tensor(torch.tensor(tmp_pre_accuracy).to(device))

        pre_loss += tmp_pre_loss.mean().item()
        pre_accuracy += tmp_pre_accuracy.item()
        nb_pre_examples += input_ids.size(0)
        nb_pre_steps += 1

    pre_loss = pre_loss / nb_pre_steps
    pre_accuracy = pre_accuracy / nb_pre_examples

    result = {'pre_loss': pre_loss, 'pre_accuracy': pre_accuracy}

    output_pre_file = os.path.join(args.output_dir, "pre_results.txt")
    with open(output_pre_file, "w", encoding="utf-8") as writer:
        logger.info("***** Pre results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return pre_loss, pre_accuracy
示例#12
0
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = inputs.to(device)

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            float_out = out.float()  # ensure float32 for loss
            loss = criterion(float_out, targets, output_sizes,
                             target_sizes).to(device)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            if args.distributed:
                loss = loss.to(device)
                loss_value = reduce_tensor(loss, args.world_size).item()
                data_time = reduce_tensor(data_time,
                                          args.world_size,
                                          reduce_op_max=True)
            else:
                loss_value = loss.item()

            # Check to ensure valid loss was calculated
            valid_loss, error = check_loss(loss, loss_value)
            if valid_loss:
                optimizer.zero_grad()
                # compute gradient
                if args.mixed_precision:
                    optimizer.backward(loss)
                    optimizer.clip_master_grads(args.max_norm)
                else:
示例#13
0
    def train(self):
        """Run the main training loop for the matting generator ``self.G``.

        Iterates ``step`` from ``start`` up to and including
        ``train_config.total_step``. ``start`` is ``resume_step + 1`` when
        ``train_config.resume_checkpoint`` is truthy, else 0. Each step:

        * fetches a batch, restarting ``train_dataloader`` when it is exhausted;
        * sets the learning rate: linear warm-up via ``utils.warmup_lr`` for the
          first ``warmup_step`` steps when ``resume_checkpoint`` is None,
          otherwise ``G_scheduler``;
        * predicts alpha mattes at three strides (os1/os4/os8). After warm-up,
          outside the region chosen from the coarser prediction, each finer
          prediction is replaced by the coarser one. Between ``warmup_step``
          and ``3 * warmup_step`` this happens with probability 1/2;
        * sums the enabled weighted rec/comp/lap losses, back-propagates,
          optionally clips gradients against a moving average of the gradient
          norm, and steps ``G_optimizer``;
        * logs to stdout/tensorboard every ``logging_step`` steps;
        * evaluates on ``test_dataloader`` every ``val_step`` steps (and at the
          last step). It saves ``latest_model`` / ``best_model`` (lowest
          test MSE) on local rank 0.
        """
        data_iter = iter(self.train_dataloader)

        if self.train_config.resume_checkpoint:
            start = self.resume_step + 1
        else:
            start = 0

        # Adaptive gradient clipping state: moving average of the gradient norm
        # (0 means "not initialized yet") and its momentum.
        moving_max_grad = 0
        moving_grad_moment = 0.999
        max_grad = 0

        for step in range(start, self.train_config.total_step + 1):
            # Restart the dataloader when it is exhausted. Only StopIteration is
            # caught so that real loading errors and KeyboardInterrupt are no
            # longer silently swallowed.
            try:
                image_dict = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_dataloader)
                image_dict = next(data_iter)

            image, alpha, trimap, mask = image_dict['image'], image_dict[
                'alpha'], image_dict['trimap'], image_dict['mask']
            image = image.cuda()
            alpha = alpha.cuda()
            trimap = trimap.cuda()
            mask = mask.cuda()
            fg_norm, bg_norm = image_dict['fg'].cuda(), image_dict['bg'].cuda()
            # train() of DistributedDataParallel has no return
            self.G.train()
            log_info = ""
            loss = 0
            # ===== Update Learning Rate =====
            if step < self.train_config.warmup_step and self.train_config.resume_checkpoint is None:
                cur_G_lr = utils.warmup_lr(self.train_config.G_lr, step + 1,
                                           self.train_config.warmup_step)
                utils.update_lr(cur_G_lr, self.G_optimizer)

            else:
                self.G_scheduler.step()
                cur_G_lr = self.G_scheduler.get_lr()[0]
            # ===== Forward G =====

            pred = self.G(image, mask)
            alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred[
                'alpha_os1'], pred['alpha_os4'], pred['alpha_os8']

            # The stride-8 output is weighted by 1 everywhere.
            weight_os8 = utils.get_unknown_tensor(trimap)
            weight_os8[...] = 1

            if step < self.train_config.warmup_step:
                # Warm-up: finer outputs weighted by the trimap-derived mask
                # (presumably the trimap's unknown region).
                weight_os4 = utils.get_unknown_tensor(trimap)
                weight_os1 = utils.get_unknown_tensor(trimap)
            elif step < self.train_config.warmup_step * 3:
                # Transition: pick trimap-based or prediction-based weights
                # with equal probability.
                if random.randint(0, 1) == 0:
                    weight_os4 = utils.get_unknown_tensor(trimap)
                    weight_os1 = utils.get_unknown_tensor(trimap)
                else:
                    weight_os4 = utils.get_unknown_tensor_from_pred(
                        alpha_pred_os8,
                        rand_width=CONFIG.model.self_refine_width1,
                        train_mode=True)
                    alpha_pred_os4[weight_os4 == 0] = alpha_pred_os8[weight_os4
                                                                     == 0]
                    weight_os1 = utils.get_unknown_tensor_from_pred(
                        alpha_pred_os4,
                        rand_width=CONFIG.model.self_refine_width2,
                        train_mode=True)
                    alpha_pred_os1[weight_os1 == 0] = alpha_pred_os4[weight_os1
                                                                     == 0]
            else:
                # Self-refinement: outside the region derived from the coarser
                # prediction, the finer output is overwritten by the coarser one.
                weight_os4 = utils.get_unknown_tensor_from_pred(
                    alpha_pred_os8,
                    rand_width=CONFIG.model.self_refine_width1,
                    train_mode=True)
                alpha_pred_os4[weight_os4 == 0] = alpha_pred_os8[weight_os4 ==
                                                                 0]
                weight_os1 = utils.get_unknown_tensor_from_pred(
                    alpha_pred_os4,
                    rand_width=CONFIG.model.self_refine_width2,
                    train_mode=True)
                alpha_pred_os1[weight_os1 == 0] = alpha_pred_os4[weight_os1 ==
                                                                 0]
            # ===== Calculate Loss =====
            # Each loss is a 2:1:1 weighted sum over the os1/os4/os8 outputs,
            # divided by 5.0 and scaled by its configured weight.
            if self.train_config.rec_weight > 0:
                self.loss_dict['rec'] = (self.regression_loss(alpha_pred_os1, alpha, loss_type='l1', weight=weight_os1) * 2 +\
                 self.regression_loss(alpha_pred_os4, alpha, loss_type='l1', weight=weight_os4) * 1 +\
                  self.regression_loss(alpha_pred_os8, alpha, loss_type='l1', weight=weight_os8) * 1) / 5.0 * self.train_config.rec_weight

            if self.train_config.comp_weight > 0:
                self.loss_dict['comp'] = (self.composition_loss(alpha_pred_os1, fg_norm, bg_norm, image, weight=weight_os1) * 2 +\
                 self.composition_loss(alpha_pred_os4, fg_norm, bg_norm, image, weight=weight_os4) * 1 +\
                  self.composition_loss(alpha_pred_os8, fg_norm, bg_norm, image, weight=weight_os8) * 1) / 5.0 * self.train_config.comp_weight

            if self.train_config.lap_weight > 0:
                self.loss_dict['lap'] = (self.lap_loss(logit=alpha_pred_os1, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os1) * 2 +\
                 self.lap_loss(logit=alpha_pred_os4, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os4) * 1 +\
                  self.lap_loss(logit=alpha_pred_os8, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os8) * 1) / 5.0 * self.train_config.lap_weight

            for loss_key in self.loss_dict.keys():
                if self.loss_dict[loss_key] is not None and loss_key in [
                        'rec', 'comp', 'lap'
                ]:
                    loss += self.loss_dict[loss_key]
            # ===== Back Propagate =====
            self.reset_grad()

            loss.backward()
            # ===== Clip Large Gradient =====
            if self.train_config.clip_grad:
                if moving_max_grad == 0:
                    # First measurement: use a very large threshold so the norm
                    # is recorded rather than clipped.
                    moving_max_grad = nn_utils.clip_grad_norm_(
                        self.G.parameters(), 1e+6)
                    max_grad = moving_max_grad
                else:
                    # Clip to twice the moving average, then update the average.
                    max_grad = nn_utils.clip_grad_norm_(
                        self.G.parameters(), 2 * moving_max_grad)
                    moving_max_grad = moving_max_grad * moving_grad_moment + max_grad * (
                        1 - moving_grad_moment)
            # ===== Update Parameters =====
            self.G_optimizer.step()
            # ===== Write Log and Tensorboard =====
            # stdout log
            if step % self.log_config.logging_step == 0:
                # reduce losses from GPUs
                if CONFIG.dist:
                    self.loss_dict = utils.reduce_tensor_dict(self.loss_dict,
                                                              mode='mean')
                    loss = utils.reduce_tensor(loss)
                # create logging information
                for loss_key in self.loss_dict.keys():
                    if self.loss_dict[loss_key] is not None:
                        log_info += loss_key.upper() + ": {:.4f}, ".format(
                            self.loss_dict[loss_key])

                self.logger.debug(
                    "Image tensor shape: {}. Trimap tensor shape: {}".format(
                        image.shape, trimap.shape))
                log_info = "[{}/{}], ".format(
                    step, self.train_config.total_step) + log_info
                log_info += "lr: {:6f}".format(cur_G_lr)
                self.logger.info(log_info)

                # tensorboard
                if step % self.log_config.tensorboard_step == 0 or step == start:  # and step > start:
                    self.tb_logger.scalar_summary('Loss', loss, step)

                    # detailed losses
                    for loss_key in self.loss_dict.keys():
                        if self.loss_dict[loss_key] is not None:
                            self.tb_logger.scalar_summary(
                                'Loss_' + loss_key.upper(),
                                self.loss_dict[loss_key], step)

                    self.tb_logger.scalar_summary('LearnRate', cur_G_lr, step)

                    if self.train_config.clip_grad:
                        self.tb_logger.scalar_summary('Moving_Max_Grad',
                                                      moving_max_grad, step)
                        self.tb_logger.scalar_summary('Max_Grad', max_grad,
                                                      step)
            # ===== TEST =====
            if ((step % self.train_config.val_step) == 0 or step
                    == self.train_config.total_step):  # and step > start:
                self.G.eval()
                test_loss = 0
                log_info = ""

                self.test_loss_dict['mse'] = 0
                self.test_loss_dict['sad'] = 0
                for loss_key in self.loss_dict.keys():
                    if loss_key in self.test_loss_dict and self.loss_dict[
                            loss_key] is not None:
                        self.test_loss_dict[loss_key] = 0

                with torch.no_grad():
                    for image_dict in self.test_dataloader:
                        image, alpha, trimap, mask = image_dict[
                            'image'], image_dict['alpha'], image_dict[
                                'trimap'], image_dict['mask']
                        alpha_shape = image_dict['alpha_shape']
                        image = image.cuda()
                        alpha = alpha.cuda()
                        trimap = trimap.cuda()
                        mask = mask.cuda()

                        pred = self.G(image, mask)

                        alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred[
                            'alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
                        # Start from the coarse prediction and overwrite the
                        # selected regions with the finer outputs.
                        alpha_pred = alpha_pred_os8.clone().detach()
                        weight_os4 = utils.get_unknown_tensor_from_pred(
                            alpha_pred,
                            rand_width=CONFIG.model.self_refine_width1,
                            train_mode=False)
                        alpha_pred[weight_os4 > 0] = alpha_pred_os4[
                            weight_os4 > 0]
                        weight_os1 = utils.get_unknown_tensor_from_pred(
                            alpha_pred,
                            rand_width=CONFIG.model.self_refine_width2,
                            train_mode=False)
                        alpha_pred[weight_os1 > 0] = alpha_pred_os1[
                            weight_os1 > 0]

                        # Crop away padding back to the original alpha size.
                        h, w = alpha_shape
                        alpha_pred = alpha_pred[..., :h, :w]
                        trimap = trimap[..., :h, :w]

                        weight = utils.get_unknown_tensor(trimap)
                        weight[...] = 1

                        # value of MSE/SAD here is different from test.py and matlab version
                        self.test_loss_dict['mse'] += self.mse(
                            alpha_pred, alpha, weight)
                        self.test_loss_dict['sad'] += self.sad(
                            alpha_pred, alpha, weight)

                        if self.train_config.rec_weight > 0:
                            self.test_loss_dict['rec'] += self.regression_loss(alpha_pred, alpha, weight=weight) \
                                                          * self.train_config.rec_weight

                # reduce losses from GPUs
                if CONFIG.dist:
                    self.test_loss_dict = utils.reduce_tensor_dict(
                        self.test_loss_dict, mode='mean')
                # ===== Write Log and Tensorboard =====
                # stdout log
                for loss_key in self.test_loss_dict.keys():
                    if self.test_loss_dict[loss_key] is not None:
                        self.test_loss_dict[loss_key] /= len(
                            self.test_dataloader)
                        # logging
                        log_info += loss_key.upper() + ": {:.4f} ".format(
                            self.test_loss_dict[loss_key])
                        self.tb_logger.scalar_summary(
                            'Loss_' + loss_key.upper(),
                            self.test_loss_dict[loss_key],
                            step,
                            phase='test')

                        if loss_key in ['rec']:
                            test_loss += self.test_loss_dict[loss_key]

                self.logger.info("TEST: LOSS: {:.4f} ".format(test_loss) +
                                 log_info)
                self.tb_logger.scalar_summary('Loss',
                                              test_loss,
                                              step,
                                              phase='test')

                # Visualize the last example of the last test batch.
                image_set = {
                    'image':
                    (utils.normalize_image(image[-1, ...]).data.cpu().numpy() *
                     255).astype(np.uint8),
                    'mask':
                    (mask[-1, ...].data.cpu().numpy() * 255).astype(np.uint8),
                    'alpha':
                    (alpha[-1, ...].data.cpu().numpy() * 255).astype(np.uint8),
                    'alpha_pred': (alpha_pred[-1, ...].data.cpu().numpy() *
                                   255).astype(np.uint8)
                }

                self.tb_logger.image_summary(image_set, step, phase='test')
                # ===== Save Model =====
                if (step % self.log_config.checkpoint_step == 0 or step == self.train_config.total_step) \
                        and CONFIG.local_rank == 0 and (step > start):
                    # Log the actual step; this previously formatted the
                    # builtin `iter` by mistake.
                    self.logger.info(
                        'Saving the trained models from step {}...'.format(
                            step))
                    self.save_model("latest_model", step, loss)
                    if self.test_loss_dict['mse'] < self.best_loss:
                        self.best_loss = self.test_loss_dict['mse']
                        self.save_model("best_model", step, loss)

                torch.cuda.empty_cache()
示例#14
0
def finetune(args, train_loader, test_loader, model, criterion):
    """Fine-tune ``model`` on the labeled data only, evaluating after each epoch.

    A new loader is built over ``train_loader.dataset``. It uses a
    ``RandomSampler`` when ``args.local_rank == -1`` and a
    ``DistributedSampler`` otherwise, with ``args.finetune_batch_size``.
    Training uses SGD with the ``args.finetune_*`` hyper-parameters and an AMP
    ``GradScaler`` that is enabled only when ``args.amp`` is set.

    After every epoch, the process with ``args.local_rank`` in ``[-1, 0]``
    does the following:
      - logs train/test metrics to ``args.writer``;
      - updates ``args.best_top1`` / ``args.best_top5`` in place;
      - writes a checkpoint through ``save_checkpoint(..., finetune=True)``.

    Args:
        args: run configuration namespace; also carries ``writer`` and the
            running best accuracies, which are mutated.
        train_loader: loader whose ``dataset`` supplies labeled samples.
        test_loader: loader passed to ``evaluate``.
        model: network to fine-tune (updated in place).
        criterion: loss applied to ``(outputs, targets)``.
    """
    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_loader = DataLoader(train_loader.dataset,
                                sampler=train_sampler(train_loader.dataset),
                                batch_size=args.finetune_batch_size,
                                num_workers=args.workers,
                                pin_memory=True)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.finetune_lr,
                          momentum=args.finetune_momentum,
                          weight_decay=args.finetune_weight_decay)
    scaler = amp.GradScaler(enabled=args.amp)

    logger.info("***** Running Finetuning *****")
    logger.info(
        f"   Finetuning steps = {len(labeled_loader)*args.finetune_epochs}")

    for epoch in range(args.finetune_epochs):
        if args.world_size > 1:
            # Reseed the distributed shuffle every epoch. The +624 offset is
            # an arbitrary constant, kept unchanged.
            labeled_loader.sampler.set_epoch(epoch + 624)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        model.train()
        end = time.time()
        labeled_iter = tqdm(labeled_loader,
                            disable=args.local_rank not in [-1, 0])
        for step, (images, targets) in enumerate(labeled_iter):
            data_time.update(time.time() - end)
            batch_size = targets.shape[0]
            images = images.to(args.device)
            targets = targets.to(args.device)
            with amp.autocast(enabled=args.amp):
                model.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if args.world_size > 1:
                loss = reduce_tensor(loss.detach(), args.world_size)
            losses.update(loss.item(), batch_size)
            batch_time.update(time.time() - end)
            # Restart the clock so data_time / batch_time measure a single
            # iteration instead of accumulating since the start of the epoch.
            end = time.time()
            labeled_iter.set_description(
                f"Finetune Epoch: {epoch+1:2}/{args.finetune_epochs:2}. Data: {data_time.avg:.2f}s. "
                f"Batch: {batch_time.avg:.2f}s. Loss: {losses.avg:.4f}. ")
        labeled_iter.close()
        if args.local_rank in [-1, 0]:
            args.writer.add_scalar("finetune/train_loss", losses.avg, epoch)
            test_loss, top1, top5 = evaluate(args, test_loader, model,
                                             criterion)
            args.writer.add_scalar("finetune/test_loss", test_loss, epoch)
            args.writer.add_scalar("finetune/acc@1", top1, epoch)
            args.writer.add_scalar("finetune/acc@5", top5, epoch)
            is_best = top1 > args.best_top1
            if is_best:
                args.best_top1 = top1
                args.best_top5 = top5

            logger.info(f"top-1 acc: {top1:.2f}")
            logger.info(f"Best top-1 acc: {args.best_top1:.2f}")

            save_checkpoint(args, {
                'step': step + 1,
                'best_top1': args.best_top1,
                'best_top5': args.best_top5,
                'student_state_dict': model.state_dict(),
                'avg_state_dict': None,
                'student_optimizer': optimizer.state_dict(),
            },
                            is_best,
                            finetune=True)
    return
示例#15
0
    def forward(self,
                x,
                x_noisy,
                std_in,
                opt,
                step=None,
                writer=None,
                init=False,
                valid=False):
        """Compute the loss for a batch, optionally step ``opt``, and log metrics.

        The loss is ``-H[Q(z|X)] - log P(z) - log P(X|z)``. ``log P(z)`` comes
        from ``self.latent_glow`` and ``log P(X|z)`` from ``self.point_AF``.

        Args:
            x: input batch, indexed as ``(batch, num_points, point_dim)``.
            x_noisy: passed to ``self.point_AF`` together with ``std_in`` and
                the latent code (presumably a noise-perturbed copy of ``x``
                and its noise level -- confirm against the caller).
            std_in: see ``x_noisy``.
            opt: optimizer. It is always zeroed; it is stepped (after
                ``loss.backward()``) only when neither ``init`` nor ``valid``
                is set.
            step: global step used for writer logging.
            writer: optional summary writer; used only when not ``valid``.
            init: if True, skip the backward pass and optimizer step.
            valid: if True, skip the backward pass, optimizer step and logging.

        Returns:
            dict with ``'entropy'`` and ``'loss'`` as Python floats. ``'prior'``,
            ``'recon'``, ``'prior_nats'`` and ``'recon_nats'`` are detached
            0-d tensors. The ``*_nats`` values are normalised by
            ``x.size(1) * x.size(2)`` and ``self.zdim`` respectively.
        """
        opt.zero_grad()
        batch_size = x.size(0)
        num_points = x.size(1)
        z_mu, z_sigma = self.encoder(x)
        if self.use_deterministic_encoder:
            # ``0 * z_sigma`` keeps z_sigma in the graph without changing z
            # (presumably so every encoder parameter receives a gradient).
            z = z_mu + 0 * z_sigma
        else:
            z = self.reparameterize_gaussian(z_mu, z_sigma)

        # Compute H[Q(z|X)]
        if self.use_deterministic_encoder:
            entropy = torch.zeros(batch_size).to(z)
        else:
            entropy = self.gaussian_entropy(z_sigma)

        # Compute the prior probability P(z)
        w, delta_log_pw = self.latent_glow(z)
        log_pw = standard_normal_logprob(w).view(batch_size,
                                                 -1).sum(1, keepdim=True)
        delta_log_pw = delta_log_pw.view(batch_size, 1)
        log_pz = log_pw - delta_log_pw

        # Compute the reconstruction likelihood P(X|z)
        z_new = z.view(*z.size())
        # Adds an exact zero that depends on log_pz, tying the reconstruction
        # branch to the prior graph. This is presumably a workaround for
        # unused-parameter checks (e.g. DDP); confirm before removing.
        z_new = z_new + (log_pz * 0.).mean()
        y, delta_log_py = self.point_AF(x_noisy, std_in, z_new)
        log_py = standard_normal_logprob(y).view(batch_size,
                                                 -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1)
        log_px = log_py - delta_log_py

        # Loss
        entropy_loss = -entropy.mean()
        recon_loss = -log_px.mean()
        prior_loss = -log_pz.mean()
        loss = entropy_loss + prior_loss + recon_loss
        if not init and not valid:
            loss.backward()
            opt.step()

        # LOGGING (after the training). Detach the metrics so the values handed
        # to the writer and returned to the caller do not keep the autograd
        # graph alive.
        loss = loss.detach().mean()
        entropy_log = entropy.detach().mean()
        recon = -log_px.detach().mean()
        prior = -log_pz.detach().mean()
        if self.distributed:
            loss = reduce_tensor(loss)
            entropy_log = reduce_tensor(entropy_log)
            recon = reduce_tensor(recon)
            prior = reduce_tensor(prior)

        recon_nats = recon / float(x.size(1) * x.size(2))
        prior_nats = prior / float(self.zdim)

        if writer is not None and not valid:
            writer.add_scalar('train/entropy', entropy_log, step)
            writer.add_scalar('train/prior', prior, step)
            writer.add_scalar('train/prior(nats)', prior_nats, step)
            writer.add_scalar('train/recon', recon, step)
            writer.add_scalar('train/recon(nats)', recon_nats, step)
            writer.add_scalar('train/loss', loss.item(), step)

        return {
            'entropy':
            entropy_log.cpu().detach().item()
            if not isinstance(entropy_log, float) else entropy_log,
            'prior_nats':
            prior_nats,
            'recon_nats':
            recon_nats,
            'prior':
            prior,
            'recon':
            recon,
            'loss':
            loss.item()
        }
示例#16
0
def train_loop(args, labeled_loader, unlabeled_loader, test_loader,
               teacher_model, student_model, avg_student_model, criterion,
               t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler,
               s_scaler):
    """Run teacher/student pseudo-label training, then fine-tune the student.

    Each step does the following:
      - Draws one labeled batch and one unlabeled ``(weak, strong)`` batch.
        An exhausted loader is restarted, and its ``DistributedSampler``
        epoch is bumped when ``args.world_size > 1``.
      - The teacher is trained on the labeled loss plus a confidence-masked
        soft pseudo-label loss on the strong view.
      - The student is updated on the teacher's hard pseudo-labels.
      - The change in the student's labeled loss across that update, minus a
        moving-average baseline, scales the teacher's ``t_loss_mpl`` term.

    Every ``args.eval_step`` steps, the process with ``args.local_rank`` in
    ``[-1, 0]`` does the following:
      - logs the averaged losses;
      - evaluates ``avg_student_model`` (or ``student_model`` when that is
        ``None``);
      - updates ``args.best_top1`` / ``args.best_top5``;
      - saves a checkpoint.

    After training, the teacher-side objects are released. The checkpoint
    ``{args.save_path}/{args.name}_best.pth.tar`` is then loaded into
    ``student_model``, preferring the averaged weights when present, and
    ``finetune`` is run on the labeled data.
    """
    logger.info("***** Running Training *****")
    logger.info(f"   Task = {args.dataset}@{args.num_labeled}")
    logger.info(f"   Total steps = {args.total_steps}")

    # Per-loader epoch counters. They reseed DistributedSampler shuffling
    # each time a loader is exhausted and restarted.
    labeled_epoch = 0
    unlabeled_epoch = 0
    if args.world_size > 1:
        labeled_loader.sampler.set_epoch(labeled_epoch)
        unlabeled_loader.sampler.set_epoch(unlabeled_epoch)

    labeled_iter = iter(labeled_loader)
    unlabeled_iter = iter(unlabeled_loader)

    # Moving-average baseline for the student's loss change.
    moving_dot_product = torch.empty(1).to(args.device)
    limit = 3.0**(0.5)  # 3 = 6 / (f_in + f_out)
    nn.init.uniform_(moving_dot_product, -limit, limit)

    for step in range(args.start_step, args.total_steps):
        # Fresh progress bar and meters for every eval interval. They are also
        # created on the first step, so a run resumed from a start_step that
        # is not a multiple of eval_step does not use undefined meters.
        if step % args.eval_step == 0 or step == args.start_step:
            pbar = tqdm(range(args.eval_step),
                        disable=args.local_rank not in [-1, 0])
            batch_time = AverageMeter()
            data_time = AverageMeter()
            s_losses = AverageMeter()
            t_losses = AverageMeter()
            t_losses_l = AverageMeter()
            t_losses_u = AverageMeter()
            t_losses_mpl = AverageMeter()
            mean_mask = AverageMeter()

        teacher_model.train()
        student_model.train()
        end = time.time()

        # Use the builtin next(). Recent PyTorch releases removed the
        # Python-2 style ``.next()`` alias from DataLoader iterators.
        # Catching only StopIteration keeps real data-loading errors (and
        # KeyboardInterrupt) from being mistaken for end-of-epoch.
        try:
            images_l, targets = next(labeled_iter)
        except StopIteration:
            if args.world_size > 1:
                labeled_epoch += 1
                labeled_loader.sampler.set_epoch(labeled_epoch)
            labeled_iter = iter(labeled_loader)
            images_l, targets = next(labeled_iter)

        try:
            (images_uw, images_us), _ = next(unlabeled_iter)
        except StopIteration:
            if args.world_size > 1:
                unlabeled_epoch += 1
                unlabeled_loader.sampler.set_epoch(unlabeled_epoch)
            unlabeled_iter = iter(unlabeled_loader)
            (images_uw, images_us), _ = next(unlabeled_iter)

        data_time.update(time.time() - end)

        images_l = images_l.to(args.device)
        images_uw = images_uw.to(args.device)
        images_us = images_us.to(args.device)
        targets = targets.to(args.device)
        with amp.autocast(enabled=args.amp):
            # Single teacher forward over [labeled, weak, strong].
            batch_size = images_l.shape[0]
            t_images = torch.cat((images_l, images_uw, images_us))
            t_logits = teacher_model(t_images)
            t_logits_l = t_logits[:batch_size]
            t_logits_uw, t_logits_us = t_logits[batch_size:].chunk(2)
            del t_logits

            t_loss_l = criterion(t_logits_l, targets)

            # Soft pseudo-labels from the weak view; only confident ones
            # (max prob >= threshold) contribute to the unlabeled loss.
            soft_pseudo_label = torch.softmax(t_logits_uw.detach() /
                                              args.temperature,
                                              dim=-1)
            max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1)
            mask = max_probs.ge(args.threshold).float()
            t_loss_u = torch.mean(
                -(soft_pseudo_label *
                  torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask)
            weight_u = args.lambda_u * min(1., (step + 1) / args.uda_steps)
            t_loss_uda = t_loss_l + weight_u * t_loss_u

            s_logits_us = student_model(images_us)
            # Student labeled loss BEFORE its update (eval mode, no grad).
            student_model.eval()
            with torch.no_grad():
                s_logits_l = student_model(images_l)
            student_model.train()

            s_loss_l_old = F.cross_entropy(s_logits_l.detach(), targets)
            s_loss = criterion(s_logits_us, hard_pseudo_label)

        s_scaler.scale(s_loss).backward()
        if args.grad_clip > 0:
            s_scaler.unscale_(s_optimizer)
            nn.utils.clip_grad_norm_(student_model.parameters(),
                                     args.grad_clip)
        s_scaler.step(s_optimizer)
        s_scaler.update()
        s_scheduler.step()
        if args.ema > 0:
            avg_student_model.update_parameters(student_model)

        with amp.autocast(enabled=args.amp):
            # Student labeled loss AFTER its update.
            student_model.eval()
            with torch.no_grad():
                s_logits_l = student_model(images_l)
            student_model.train()
            s_loss_l_new = F.cross_entropy(s_logits_l.detach(), targets)
            dot_product = s_loss_l_new - s_loss_l_old
            moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
            dot_product = dot_product - moving_dot_product
            _, hard_pseudo_label = torch.max(t_logits_us.detach(), dim=-1)
            t_loss_mpl = dot_product * F.cross_entropy(t_logits_us,
                                                       hard_pseudo_label)
            t_loss = t_loss_uda + t_loss_mpl

        t_scaler.scale(t_loss).backward()
        if args.grad_clip > 0:
            t_scaler.unscale_(t_optimizer)
            nn.utils.clip_grad_norm_(teacher_model.parameters(),
                                     args.grad_clip)
        t_scaler.step(t_optimizer)
        t_scaler.update()
        t_scheduler.step()

        teacher_model.zero_grad()
        student_model.zero_grad()

        if args.world_size > 1:
            s_loss = reduce_tensor(s_loss.detach(), args.world_size)
            t_loss = reduce_tensor(t_loss.detach(), args.world_size)
            t_loss_l = reduce_tensor(t_loss_l.detach(), args.world_size)
            t_loss_u = reduce_tensor(t_loss_u.detach(), args.world_size)
            t_loss_mpl = reduce_tensor(t_loss_mpl.detach(), args.world_size)
            mask = reduce_tensor(mask, args.world_size)

        s_losses.update(s_loss.item())
        t_losses.update(t_loss.item())
        t_losses_l.update(t_loss_l.item())
        t_losses_u.update(t_loss_u.item())
        t_losses_mpl.update(t_loss_mpl.item())
        mean_mask.update(mask.mean().item())

        batch_time.update(time.time() - end)
        pbar.set_description(
            f"Train Iter: {step+1:3}/{args.total_steps:3}. "
            f"LR: {get_lr(s_optimizer):.4f}. Data: {data_time.avg:.2f}s. "
            f"Batch: {batch_time.avg:.2f}s. S_Loss: {s_losses.avg:.4f}. "
            f"T_Loss: {t_losses.avg:.4f}. Mask: {mean_mask.avg:.4f}. ")
        pbar.update()
        if args.local_rank in [-1, 0]:
            args.writer.add_scalar("lr", get_lr(s_optimizer), step)

        args.num_eval = step // args.eval_step
        if (step + 1) % args.eval_step == 0:
            pbar.close()
            if args.local_rank in [-1, 0]:
                args.writer.add_scalar("train/1.s_loss", s_losses.avg,
                                       args.num_eval)
                args.writer.add_scalar("train/2.t_loss", t_losses.avg,
                                       args.num_eval)
                args.writer.add_scalar("train/3.t_labeled", t_losses_l.avg,
                                       args.num_eval)
                args.writer.add_scalar("train/4.t_unlabeled", t_losses_u.avg,
                                       args.num_eval)
                args.writer.add_scalar("train/5.t_mpl", t_losses_mpl.avg,
                                       args.num_eval)
                args.writer.add_scalar("train/6.mask", mean_mask.avg,
                                       args.num_eval)

                test_model = avg_student_model if avg_student_model is not None else student_model
                test_loss, top1, top5 = evaluate(args, test_loader, test_model,
                                                 criterion)

                args.writer.add_scalar("test/loss", test_loss, args.num_eval)
                args.writer.add_scalar("test/acc@1", top1, args.num_eval)
                args.writer.add_scalar("test/acc@5", top5, args.num_eval)

                is_best = top1 > args.best_top1
                if is_best:
                    args.best_top1 = top1
                    args.best_top5 = top5

                logger.info(f"top-1 acc: {top1:.2f}")
                logger.info(f"Best top-1 acc: {args.best_top1:.2f}")

                save_checkpoint(
                    args, {
                        'step':
                        step + 1,
                        'teacher_state_dict':
                        teacher_model.state_dict(),
                        'student_state_dict':
                        student_model.state_dict(),
                        'avg_state_dict':
                        avg_student_model.state_dict()
                        if avg_student_model is not None else None,
                        'best_top1':
                        args.best_top1,
                        'best_top5':
                        args.best_top5,
                        'teacher_optimizer':
                        t_optimizer.state_dict(),
                        'student_optimizer':
                        s_optimizer.state_dict(),
                        'teacher_scheduler':
                        t_scheduler.state_dict(),
                        'student_scheduler':
                        s_scheduler.state_dict(),
                        'teacher_scaler':
                        t_scaler.state_dict(),
                        'student_scaler':
                        s_scaler.state_dict(),
                    }, is_best)
    # finetune: release teacher-side objects, reload the best student and
    # fine-tune it on the labeled data.
    del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader
    del s_scaler, s_scheduler, s_optimizer
    ckpt_name = f'{args.save_path}/{args.name}_best.pth.tar'
    loc = f'cuda:{args.gpu}'
    checkpoint = torch.load(ckpt_name, map_location=loc)
    logger.info(f"=> loading checkpoint '{ckpt_name}'")
    if checkpoint['avg_state_dict'] is not None:
        model_load_state_dict(student_model, checkpoint['avg_state_dict'])
    else:
        model_load_state_dict(student_model, checkpoint['student_state_dict'])
    finetune(args, labeled_loader, test_loader, student_model, criterion)
    return
示例#17
0
def fit(num_epoch=args['num_train_epochs']):
    """Train the module-level ``model`` on ``train_dataloader``.

    The function relies on module-level globals: ``model``, ``optimizer``,
    ``scheduler``, ``train_dataloader``, ``loss_fct``, ``accuracy``,
    ``args``, ``device``, ``n_gpu``, ``num_labels``, ``params``, ``loggers``,
    ``model_path``, ``t_total`` and ``warmup_linear``.

    Gradients are accumulated over ``args['gradient_accumulation_steps']``
    batches, and the loss is backpropagated through ``amp.scale_loss``.
    A batch that raises a CUDA out-of-memory ``RuntimeError`` is skipped;
    any other ``RuntimeError`` is re-raised.

    Periodic work is driven by ``global_step``:
      - every 100 successful steps, ``eval()`` runs and training/eval metrics
        are logged;
      - every 200 steps, the model is saved if the eval accuracy is the best
        seen so far (tracked in the global list ``params``).

    The learning rate is updated once per epoch. A warm-up schedule is used
    when ``args['fp16']`` is set; otherwise ``scheduler.step()`` is called.

    Args:
        num_epoch: number of epochs (defaults to ``args['num_train_epochs']``,
            evaluated once when the function is defined).
    """
    global_step = 0
    model.train()
    for i_ in tqdm(range(int(num_epoch)), desc="Epoch"):
        print('当前阶段******************************', i_)
        tr_loss, tr_accuracy = 0, 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for index, batch in enumerate(tqdm(train_dataloader,
                                           desc="Iteration")):

            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            try:
                logits = model(input_ids, segment_ids, input_mask, label_ids)
                tmp_train_loss = loss_fct(logits.view(-1, num_labels),
                                          label_ids.squeeze())
                tmp_train_accuracy = accuracy(
                    logits.view(-1, num_labels).detach().cpu().numpy(),
                    label_ids.squeeze().detach().cpu().numpy())
                if n_gpu > 1:
                    # mean() to average on multi-gpu.
                    tmp_train_loss = tmp_train_loss.mean()

                if args["local_rank"] != -1:
                    tmp_train_loss = reduce_tensor(tmp_train_loss)
                    tmp_train_accuracy = reduce_tensor(
                        torch.tensor(tmp_train_accuracy).to(device))

                tmp_train_loss = tmp_train_loss / args[
                    'gradient_accumulation_steps']
                with amp.scale_loss(tmp_train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                if (index + 1) % args['gradient_accumulation_steps'] == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                tr_loss += tmp_train_loss.item()
                tr_accuracy += tmp_train_accuracy.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                global_step += 1
            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                print('| WARNING: ran out of memory')
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                # The batch was skipped and global_step did not advance. Skip
                # the periodic eval/save below; otherwise it would re-run for
                # the same global_step, or divide by nb_tr_steps == 0 when no
                # step has succeeded yet this epoch.
                continue

            # Tensorboard Logging
            eval_loss, eval_accuracy = 0, 0
            if global_step % 100 == 0:
                eval_loss, eval_accuracy = eval()
                # eval() puts the model into eval mode; switch back so dropout
                # etc. stay active for the remaining training steps.
                model.train()

                logger.info('tr_loss:{} & tr_accuracy:{}'.format(
                    tr_loss / nb_tr_steps, tr_accuracy / nb_tr_examples))
                logger.info('eval_loss:{} & eval_accuracy:{}'.format(
                    eval_loss, eval_accuracy))
                info = {
                    'tr_loss': tr_loss / nb_tr_steps,
                    'tr_accuracy': tr_accuracy / nb_tr_examples
                }
                for tag, value in info.items():
                    loggers.scalar_summary(tag, value, global_step + 1)
                info = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
                for tag, value in info.items():
                    loggers.scalar_summary(tag, value, global_step + 1)

            # Save the model when the latest eval accuracy is the best so far.
            if global_step % 200 == 0:
                params.append(eval_accuracy)
                if eval_accuracy >= max(params):
                    if args["local_rank"] == -1:
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            model_path, "finetuned_pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                    elif args["local_rank"] == 0:
                        checkpoint = {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'amp': amp.state_dict()
                        }
                        output_model_file = os.path.join(
                            model_path, "amp_checkpoint.pt")
                        torch.save(checkpoint, output_model_file)

        if args["fp16"]:
            # modify learning rate with special warm up BERT uses
            lr_this_step = args['learning_rate'] * warmup_linear(
                global_step / t_total, args['warmup_proportion'])
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
        else:
            scheduler.step()
示例#18
0
def ens_attack(val_loader, model, criterion, args, log, num_mc_samples=20):
    """Evaluate a Monte-Carlo ensemble of ``model`` under a PGD attack.

    Each batch from ``val_loader`` is de-normalized with ``x * std + mean``.
    It is then attacked with ``args.num_steps`` signed-gradient steps of size
    ``args.step_size``, projected onto an L-inf ball of radius
    ``args.epsilon`` and clipped to ``[0, 1]``; the attack starts from a
    random point when ``args.random`` is set.

    The gradient used at each step is the average of ``num_mc_samples``
    per-pass input gradients. Each is weighted by the (normalized)
    probability that pass assigns to the true class. The model is presumably
    stochastic across forward passes (e.g. MC dropout) -- confirm.

    The adversarial batch is scored by averaging ``num_mc_samples`` softmax
    outputs. Top-1/top-5 accuracy and ``criterion`` applied to the log of the
    averaged probabilities are passed through ``reduce_tensor`` and reported
    with ``print_log``. The process with ``args.gpu == 0`` saves the
    per-example mutual information (predictive entropy minus mean per-pass
    entropy) to ``{args.save_path}/mis_advg.npy``.

    Raises:
        ValueError: if ``args.dataset`` is not ``cifar10``, ``cifar100`` or
            ``imagenet``.
    """
    import contextlib

    def _grad(X, y, mean, std):
        """Probability-weighted average of input gradients over MC passes."""
        probs = torch.zeros(num_mc_samples, X.shape[0]).cuda(args.gpu)
        grads = torch.zeros(num_mc_samples, *list(X.shape)).cuda(args.gpu)
        # no_sync() exists only on DistributedDataParallel. Fall back to a
        # no-op context so the attack also runs on an unwrapped model.
        sync_ctx = model.no_sync if hasattr(
            model, 'no_sync') else contextlib.nullcontext
        for j in range(num_mc_samples):
            with sync_ctx():
                with torch.enable_grad():
                    X.requires_grad_()
                    output = model(X.sub(mean).div(std))
                    loss = torch.nn.functional.cross_entropy(output,
                                                             y,
                                                             reduction='none')
                    grad_ = torch.autograd.grad(
                        [loss], [X],
                        grad_outputs=torch.ones_like(loss),
                        retain_graph=False)[0].detach()
            grads[j] = grad_
            probs[j] = torch.gather(output.detach().softmax(-1), 1,
                                    y[:, None]).squeeze()
        probs /= probs.sum(0)
        # Assumes 4-D image input (N, C, H, W).
        grad_ = (grads * probs[:, :, None, None, None]).sum(0)
        return grad_

    def _pgd_whitebox(X, y, mean, std):
        """Attack X (pixel space) and score the ensemble on the result."""
        X_pgd = X.clone()
        if args.random:
            # empty_like allocates on X's device/dtype (the legacy
            # torch.cuda.FloatTensor used the *current* CUDA device).
            X_pgd += torch.empty_like(X_pgd).uniform_(-args.epsilon,
                                                      args.epsilon)

        for _ in range(args.num_steps):
            grad_ = _grad(X_pgd, y, mean, std)
            X_pgd += args.step_size * grad_.sign()
            eta = torch.clamp(X_pgd - X, -args.epsilon, args.epsilon)
            X_pgd = torch.clamp(X + eta, 0, 1.0)

        # Running means over MC passes: `mis` holds the mean per-pass entropy,
        # `preds` the mean softmax.
        mis = 0
        preds = 0
        for ens in range(num_mc_samples):
            output = model(X_pgd.sub(mean).div(std))
            mis = (mis * ens + (-output.softmax(-1) *
                                (output).log_softmax(-1)).sum(1)) / (ens + 1)
            preds = (preds * ens + output.softmax(-1)) / (ens + 1)

        loss = criterion((preds + 1e-8).log(), y)
        prec1, prec5 = accuracy(preds, y, topk=(1, 5))
        # Mutual information = entropy of the mean prediction - mean entropy.
        mis = (-preds * (preds + 1e-8).log()).sum(1) - mis
        return loss, prec1, prec5, mis

    # Per-channel normalization statistics, in [0, 1] pixel units.
    norm_stats = {
        'cifar10': ([x / 255 for x in [125.3, 123.0, 113.9]],
                    [x / 255 for x in [63.0, 62.1, 66.7]]),
        'cifar100': ([x / 255 for x in [129.3, 124.1, 112.4]],
                     [x / 255 for x in [68.2, 65.4, 70.4]]),
        'imagenet': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }
    if args.dataset not in norm_stats:
        raise ValueError(
            f'ens_attack: unsupported dataset {args.dataset!r}; expected one '
            f'of {sorted(norm_stats)}')
    mean_vals, std_vals = norm_stats[args.dataset]
    mean = torch.from_numpy(np.array(mean_vals)).view(
        1, 3, 1, 1).cuda(args.gpu).float()
    std = torch.from_numpy(np.array(std_vals)).view(
        1, 3, 1, 1).cuda(args.gpu).float()

    losses, top1, top5 = 0, 0, 0
    model.eval()
    with torch.no_grad():
        mis = []
        for images, target in val_loader:
            # Undo the loader's normalization; the attack works in pixel space
            # and re-normalizes before every model call.
            images = images.cuda(args.gpu,
                                 non_blocking=True).mul_(std).add_(mean)
            target = target.cuda(args.gpu, non_blocking=True)
            loss, prec1, prec5, mis_ = _pgd_whitebox(images, target, mean,
                                                     std)
            losses += loss * target.size(0)
            top1 += prec1 * target.size(0)
            top5 += prec5 * target.size(0)
            mis.append(mis_)

        mis = torch.cat(mis, 0)
        losses /= mis.size(0)
        top1 /= mis.size(0)
        top5 /= mis.size(0)
        losses = reduce_tensor(losses.data, args)
        top1 = reduce_tensor(top1.data, args)
        top5 = reduce_tensor(top5.data, args)

        if args.distributed:
            mis = dist_collect(mis)

    print_log(
        'ADV ensemble TOP1: {:.4f}, TOP5: {:.4f}, LOS: {:.4f}'.format(
            top1.item(), top5.item(), losses.item()), log)
    if args.gpu == 0:
        np.save(os.path.join(args.save_path, 'mis_advg.npy'),
                mis.data.cpu().numpy())
示例#19
0
def eval():
    """Evaluate the global ``model`` on ``eval_examples`` and save the metrics.

    Relies entirely on module-level state defined elsewhere in this file:
    ``args`` (dict-like config), ``eval_examples``, ``label_list``,
    ``tokenizer``, ``model``, ``device``, ``loss_fct``, ``num_labels``,
    ``accuracy`` and ``reduce_tensor``.

    Examples are converted to features, batched in order (``SequentialSampler``)
    and run through the model without gradient tracking. The reported loss is
    the mean of the per-batch losses. The reported accuracy is the sum of the
    per-batch ``accuracy`` values divided by the number of examples seen.
    This presumes ``accuracy`` returns a count of correct predictions; that
    helper is not visible here.

    When ``args["local_rank"] != -1``, per-batch loss and accuracy are passed
    through ``reduce_tensor`` before accumulation.

    Side effects: creates ``args['output_dir']`` if it is missing. Writes
    ``eval_results.txt`` into it, overwriting any previous file.

    Returns:
        tuple: ``(eval_loss, eval_accuracy)``. Both are 0 for an empty
        evaluation set instead of raising ``ZeroDivisionError``.
    """
    args['output_dir'].mkdir(exist_ok=True)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args['max_seq_length'],
                                                 tokenizer)
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args['eval_batch_size'])
    # Log the device once, rather than printing it on every batch.
    logger.info("  Device = %s", device)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in eval_features],
                                 dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                              all_label_ids)
    # Run prediction for the full data, in order.
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args['eval_batch_size'])

    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        # One label per logit row. Unlike squeeze(), reshape(-1) keeps a 1-D
        # shape even when the final batch holds a single example.
        flat_labels = label_ids.reshape(-1)

        # The loss is computed inside no_grad as well: evaluation never
        # backpropagates, so recording an autograd graph only wastes memory.
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask)
            flat_logits = logits.view(-1, num_labels)
            tmp_eval_loss = loss_fct(flat_logits, flat_labels)

        tmp_eval_accuracy = accuracy(flat_logits.cpu().numpy(),
                                     flat_labels.cpu().numpy())

        if args["local_rank"] != -1:
            tmp_eval_loss = reduce_tensor(tmp_eval_loss)
            # Built as float so that a reduce_tensor which divides in place
            # (e.g. averaging by world size) cannot fail on an integer tensor.
            # NOTE(review): reduce_tensor's definition is not visible here.
            tmp_eval_accuracy = reduce_tensor(
                torch.tensor(tmp_eval_accuracy, dtype=torch.float,
                             device=device))

        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy.item()
        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    # Guard against an empty evaluation set instead of dividing by zero.
    eval_loss = eval_loss / max(nb_eval_steps, 1)
    eval_accuracy = eval_accuracy / max(nb_eval_examples, 1)

    result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}

    output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
    with open(output_eval_file, "w", encoding="utf-8") as writer:
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return eval_loss, eval_accuracy
            data_time.update(time.time() - end)
            inputs = inputs.to(device)

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            float_out = out.float()  # ensure float32 for loss
            #print(float_out.to('cpu'))
            #break
            loss = criterion(float_out.to('cpu'), targets, output_sizes,
                             target_sizes).to(device)
            loss = loss / inputs.size(0)  # average the loss by minibatchi

            if args.distributed:
                loss = loss.to(device)
                loss_value = reduce_tensor(loss, args.world_size).item()
            else:
                loss_value = loss.item()

            # Check to ensure valid loss was calculated
            valid_loss, error = check_loss(loss, loss_value)
            if valid_loss:
                optimizer.zero_grad()
                # compute gradient

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_norm)
                optimizer.step()
                #if i%16 == 15:
示例#21
0
    def forward(self, x, step, writer=None):
        """Encode ``x``, evaluate the prior and reconstruction terms, build the loss.

        Args:
            x: input batch. The original comment states it is ``(n, l, c)``.
                Dim 0 is used as the batch size and dim 1 as the number of
                points, and ``x.size(1) * x.size(2)`` normalizes the
                reconstruction term into nats.
            step: global step, used only as the x-axis of ``writer`` scalars.
            writer: optional object with ``add_scalar``. Nothing is logged
                when it is ``None``.

        Returns:
            ``(stats, loss)``:

            * ``loss`` is ``entropy_loss + prior_loss + recon_loss``, each
              scaled by its ``*_weight`` attribute and still attached to the
              autograd graph.
            * ``stats`` is a dict with ``'entropy'`` (a Python number) and
              ``'prior_nats'`` / ``'recon_nats'`` (detached tensors).
        """
        # x is (n, l, c)
        batch_size = x.size(0)
        num_points = x.size(1)
        z_mu, z_sigma = self.encoder(x)  # assume z_sigma is ln(sigma)
        if self.use_deterministic_encoder:
            # "+ 0 * z_sigma" keeps z_sigma in the graph with zero gradient.
            # Presumably this lets every encoder output take part in backward;
            # confirm.
            z = z_mu + 0 * z_sigma
        else:
            z = self.reparameterize_gaussian(z_mu, z_sigma)

        # Compute H[Q(z|X)]
        if self.use_deterministic_encoder:
            entropy = torch.zeros(batch_size).to(z)
        else:
            entropy = self.gaussian_entropy(z_sigma)

        # Compute the prior probability P(z)
        if self.use_latent_flow:
            w, delta_log_pw = self.latent_rsf(z,
                                              torch.zeros(batch_size, 1).to(z))
            log_pw = standard_normal_logprob(w).view(batch_size,
                                                     -1).sum(1, keepdim=True)
            delta_log_pw = delta_log_pw.view(batch_size, 1)
            log_pz = log_pw - delta_log_pw
        else:
            log_pz = torch.zeros(batch_size, 1).to(z)

        # Compute the reconstruction likelihood P(X|z)
        # NOTE(review): z is not passed to point_rsf, so this term does not
        # depend on the latent code. The commented-out z_new lines suggest it
        # once was; confirm this is intended.
        # z_new = z.view(*z.size())
        # z_new = z_new + (log_pz * 0.).mean()
        y, delta_log_py = self.point_rsf(
            x,
            torch.zeros(batch_size, num_points, 1).to(x))
        log_py = standard_normal_logprob(y).view(batch_size,
                                                 -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1)
        log_px = log_py - delta_log_py

        # Loss
        entropy_loss = -entropy.mean() * self.entropy_weight
        recon_loss = -log_px.mean() * self.recon_weight
        prior_loss = -log_pz.mean() * self.prior_weight
        loss = entropy_loss + prior_loss + recon_loss

        # LOGGING (after the training)
        # The logged quantities are only reported, never differentiated, so
        # they are detached. Returning graph-attached tensors would keep this
        # step's autograd graph alive wherever the caller stores them
        # (e.g. running-average meters).
        entropy_log = entropy.mean().detach()
        recon = -log_px.mean().detach()
        prior = -log_pz.mean().detach()
        if self.distributed:
            entropy_log = reduce_tensor(entropy_log)
            recon = reduce_tensor(recon)
            prior = reduce_tensor(prior)

        recon_nats = recon / float(x.size(1) * x.size(2))
        prior_nats = prior / float(self.zdim)

        if writer is not None:
            writer.add_scalar('train/entropy', entropy_log, step)
            writer.add_scalar('train/prior', prior, step)
            writer.add_scalar('train/prior(nats)', prior_nats, step)
            writer.add_scalar('train/recon', recon, step)
            writer.add_scalar('train/recon(nats)', recon_nats, step)

        return {
            'entropy':
            entropy_log.cpu().detach().item()
            if not isinstance(entropy_log, float) else entropy_log,
            'prior_nats':
            prior_nats,
            'recon_nats':
            recon_nats,
        }, loss