Exemplo n.º 1
0
def train(args, model, loader, optimizer, bce):
    """Train ``model`` for one epoch and return the mean loss.

    Args:
        args: namespace with at least ``amp`` (bool) selecting apex
            mixed-precision loss scaling.
        model: network invoked as ``model(data, descriptor)``.
        loader: DataLoader yielding dicts with 'data', 'descriptor'
            and 'target' entries.
        optimizer: optimizer stepped once per batch.
        bce: loss callable applied to (output, two-column target).

    Returns:
        float: average loss over the whole epoch.
    """
    model.train()
    epoch_loss = utils.AverageMeter()
    batch_loss = utils.AverageMeter()

    # Print ~5 progress lines per epoch. Guard against short loaders:
    # len(loader) < 5 would make print_stats == 0 and crash the modulo below.
    print_stats = max(1, len(loader) // 5)
    for batch_idx, sample in enumerate(loader):
        data = sample['data'].float().cuda()
        # NOTE(review): no .float() here, unlike the test() helper — confirm
        # the descriptor dtype is already what the model expects.
        descriptor = sample['descriptor'].cuda()
        target = sample['target'].float().cuda()
        # Two-column soft target: [P(class 0), P(class 1)].
        target = torch.stack([1 - target, target], dim=1)

        optimizer.zero_grad()
        out = model(data, descriptor)
        loss = bce(out, target)
        if args.amp:
            # apex AMP: scale the loss so fp16 gradients do not underflow.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        batch_loss.update(loss.item())
        epoch_loss.update(loss.item())

        if batch_loss.count % print_stats == 0:
            text = '{} -- [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
            print(
                text.format(time.strftime("%H:%M:%S"), (batch_idx + 1),
                            (len(loader)),
                            100. * (batch_idx + 1) / (len(loader)),
                            batch_loss.avg))
            batch_loss.reset()
    print('--- Train: \tLoss: {:.6f} ---'.format(epoch_loss.avg))
    return epoch_loss.avg
Exemplo n.º 2
0
def train_epoch(train_loader, model, model_fn, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Relies on module-level globals: ``cfg`` (hyper-parameters), ``logger``,
    ``writer`` (scalar logger, TensorBoard-style) and ``utils``.

    Args:
        train_loader: sized iterable of batches.
        model: network to train; set to train mode here.
        model_fn: callable ``(batch, model, epoch) -> (loss, preds,
            visual_dict, meter_dict)`` performing the forward pass.
            ``meter_dict`` maps metric name -> (value, count).
        optimizer: optimizer stepped once per batch.
        epoch: 1-based epoch index (drives the LR schedule and logging).
    """
    iter_time = utils.AverageMeter()  # wall time per iteration
    data_time = utils.AverageMeter()  # time spent waiting on the loader
    am_dict = {}  # metric name -> AverageMeter, created lazily below

    model.train()
    start_epoch = time.time()
    end = time.time()
    for i, batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        # Free cached CUDA memory before each forward pass.
        torch.cuda.empty_cache()

        ##### adjust learning rate (step schedule, epoch is 1-based)
        utils.step_learning_rate(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.multiplier)

        ##### prepare input and forward
        loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch)

        ##### meter_dict: accumulate (value, count) pairs per metric
        for k, v in meter_dict.items():
            if k not in am_dict.keys():
                am_dict[k] = utils.AverageMeter()
            am_dict[k].update(v[0], v[1])

        ##### backward (with optional gradient-norm clipping)
        optimizer.zero_grad()
        loss.backward()
        if cfg.clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip)
        optimizer.step()

        ##### time and print: estimate remaining wall time from the
        ##### running mean iteration time
        current_iter = (epoch - 1) * len(train_loader) + i + 1
        max_iter = cfg.epochs * len(train_loader)
        remain_iter = max_iter - current_iter

        iter_time.update(time.time() - end)
        end = time.time()

        remain_time = remain_iter * iter_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        # Mixes positional and keyword format arguments; valid str.format usage.
        sys.stdout.write(
            "epoch: {}/{} iter: {}/{} loss: {:.6f}({:.6f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}\n".format
            (epoch, cfg.epochs, i + 1, len(train_loader), am_dict['loss'].val, am_dict['loss'].avg,
             data_time.val, data_time.avg, iter_time.val, iter_time.avg, remain_time=remain_time))
        if (i == len(train_loader) - 1): print()


    logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))

    # cfg.config.split('/')[-1][:-5] strips the directory and a 5-char
    # extension (presumably '.yaml') to build the experiment name.
    f = utils.checkpoint_save(model, cfg.exp_path, cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq)
    logger.info('Saving {}'.format(f))

    # NOTE(review): visual_dict here is whatever the LAST batch produced,
    # and is undefined if train_loader is empty — confirm intended.
    for k in am_dict.keys():
        if k in visual_dict.keys():
            writer.add_scalar(k+'_train', am_dict[k].avg, epoch)
Exemplo n.º 3
0
def eval_epoch(val_loader, model, model_fn, epoch):
    """Run one evaluation pass over ``val_loader`` and log averaged metrics.

    Uses module-level globals ``cfg``, ``logger``, ``writer`` and ``utils``.
    ``model_fn`` must return ``(loss, preds, visual_dict, meter_dict)``
    where ``meter_dict`` maps metric name -> (value, count).
    """
    logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    am_dict = {}  # metric name -> AverageMeter

    with torch.no_grad():
        model.eval()
        tic = time.time()
        n_batches = len(val_loader)
        for idx, batch in enumerate(val_loader):
            # Forward pass only; no backward in evaluation.
            loss, preds, visual_dict, meter_dict = model_fn(batch, model, epoch)

            # Lazily create a meter per metric name and update it.
            for key, val in meter_dict.items():
                if key not in am_dict:
                    am_dict[key] = utils.AverageMeter()
                am_dict[key].update(val[0], val[1])

            # Progress line overwritten in place via carriage return.
            sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(
                idx + 1, n_batches, am_dict['loss'].val, am_dict['loss'].avg))
            if idx == n_batches - 1:
                print()

        logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(
            epoch, cfg.epochs, am_dict['loss'].avg, time.time() - tic))

        # Scalars are emitted only for metrics that also appear in the
        # last batch's visual_dict.
        for key in am_dict:
            if key in visual_dict:
                writer.add_scalar(key + '_eval', am_dict[key].avg, epoch)
Exemplo n.º 4
0
def eval_epoch(val_loader, model, model_fn, epoch):
    """Distributed-aware evaluation epoch.

    When ``cfg.dist`` is set, each metric is all-reduced across processes
    so every rank ends up with the global count-weighted mean. Only rank 0
    (``cfg.local_rank == 0``) prints, logs and writes scalars.

    Uses module-level globals ``cfg``, ``logger``, ``writer``, ``utils``
    and ``dist`` (torch.distributed).
    """
    if cfg.local_rank == 0:
        logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
    am_dict = {}  # metric name -> AverageMeter

    with torch.no_grad():
        model.eval()
        start_epoch = time.time()
        for i, batch in enumerate(val_loader):

            ##### prepare input and forward
            loss, preds, visual_dict, meter_dict = model_fn(
                batch, model, epoch)

            ##### merge(allreduce) multi-gpu
            if cfg.dist:
                for k, v in visual_dict.items():
                    # meter_dict[k] is (value, count); count weights the mean.
                    count = meter_dict[k][1]
                    # print("[PID {}] Before allreduce: key {} value {} count {}".format(os.getpid(), k, float(v), count))

                    # Convert the local mean back to a sum, all-reduce sums
                    # and counts across ranks, then divide to obtain the
                    # global weighted mean. The tuple expression performs
                    # two in-place all_reduce calls on one line.
                    v = v * count
                    count = loss.new_tensor([int(count)], dtype=torch.long)
                    dist.all_reduce(v), dist.all_reduce(count)
                    count = count.item()
                    v = v / count
                    # print("[PID {}] After allreduce: key {} value {} count {}".format(os.getpid(), k, float(v), count))

                    visual_dict[k] = v
                    meter_dict[k] = (float(v), count)

            ##### meter_dict: accumulate per-metric running averages
            for k, v in meter_dict.items():
                if k not in am_dict.keys():
                    am_dict[k] = utils.AverageMeter()
                am_dict[k].update(v[0], v[1])

            ##### print (rank 0 only; \r overwrites the progress line)
            if cfg.local_rank == 0:
                sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(
                    i + 1, len(val_loader), am_dict['loss'].val,
                    am_dict['loss'].avg))
                if (i == len(val_loader) - 1): print()

        if cfg.local_rank == 0:
            logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(
                epoch, cfg.epochs, am_dict['loss'].avg,
                time.time() - start_epoch))

            # NOTE(review): visual_dict is from the last batch only, and is
            # undefined if val_loader is empty.
            for k in am_dict.keys():
                if k in visual_dict.keys():
                    writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
Exemplo n.º 5
0
def test(args, model, loader, save_path, bce, training=True):
    """Evaluate ``model`` on ``loader`` and compute classification metrics.

    Args:
        args: unused here; kept for interface symmetry with ``train``.
        model: network invoked as ``model(data, descriptor)``.
        loader: DataLoader yielding dicts with 'data', 'descriptor',
            'target' and 'id' entries.
        save_path: directory where 'confidence.csv' is written when
            ``training`` is False.
        bce: loss callable applied to (output, two-column target).
        training: when False, also print metrics and dump per-patient
            scores to CSV.

    Returns:
        tuple: ``(mean_loss, f1, flag)`` where ``flag`` is False when the
        model predicted all-negative or all-positive (degenerate output).
    """
    model.eval()
    epoch_loss = utils.AverageMeter()
    # Plain int counter of positive predictions (was a CUDA tensor before;
    # .item() below keeps it on the CPU and makes the final comparison cheap).
    count = 0
    labels, patients, scores, predictions = [], [], [], []

    for sample in loader:
        data = sample['data'].float().cuda()
        descriptor = sample['descriptor'].float().cuda()
        target = sample['target'].float().cuda()
        patients.extend(sample['id'].tolist())
        labels.extend(sample['target'].tolist())

        with torch.no_grad():
            out = model(data, descriptor)
        # Two-column soft target: [P(class 0), P(class 1)].
        loss = bce(out, torch.stack([1 - target, target], dim=1))
        epoch_loss.update(loss.item())

        # Class-1 confidence is the ranking score for ROC/AP.
        confidence = F.softmax(out, dim=1)
        scores.extend(confidence[:, 1].tolist())

        pred = torch.argmax(confidence, dim=1)
        predictions.extend(pred.tolist())
        count += pred.sum().item()

    print('--- Val: \tLoss: {:.6f} ---'.format(epoch_loss.avg))

    # Metrics
    # NOTE(review): roc_auc_score raises ValueError when labels contain a
    # single class — confirm the loader always yields both classes.
    roc = roc_auc_score(labels, scores)
    ap = average_precision_score(labels, scores)
    f1 = f1_score(labels, predictions)

    if not training:
        print('ROC', roc, 'AP', ap, 'F1', f1)
        rows = zip(patients, scores)
        # newline='' per the csv module docs (avoids blank rows on Windows).
        with open(os.path.join(save_path, 'confidence.csv'), "w", newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['ROC:', roc])
            writer.writerow(['AP:', ap])
            writer.writerow(['F1:', f1])
            for row in rows:
                writer.writerow(row)

    # Degenerate predictions (all 0s or all 1s) make F1 unreliable.
    flag = 0 < count < len(loader.dataset)
    return epoch_loss.avg, f1, flag
Exemplo n.º 6
0
    def reset_counter(self):
        """Resets counters.

        Zeroes the sample count and replaces every tracking meter with a
        fresh ``utils.AverageMeter``; latency and relax meters are only
        recreated when the corresponding feature is enabled.
        """
        self.count = 0
        for name in ('loss', 'accuracy', 'loss_diff_sign'):
            setattr(self, name, utils.AverageMeter())
        if self.latency_cost:
            for name in ('latency_pred_loss', 'latency_value'):
                setattr(self, name, utils.AverageMeter())
        if self.meta_loss == 'relax':
            self.relax_pred_loss = utils.AverageMeter()
Exemplo n.º 7
0
    def __init__(self, model, alpha, args, writer, logging):
        """Set up the architecture-parameter optimizer and meta-loss state.

        Args:
            model: the weight-holding network being searched over.
            alpha: module holding architecture parameters; expected to
                expose ``parameters()`` and ``module.alphas_size()``.
            args: namespace with search hyper-parameters (learning rates,
                ``meta_loss``, ``target_latency``, ``epochs``, ...).
            writer: scalar logger (TensorBoard-style).
            logging: logger instance, also passed to the surrogate model.
        """
        self.args = args
        self.model = model
        self.alpha = alpha
        self.logging = logging
        # Adam over the architecture parameters only (not model weights).
        self.arch_optimizer = torch.optim.Adam(
            self.alpha.parameters(),
            lr=args.arch_learning_rate,
            weight_decay=args.arch_weight_decay)

        # Cosine decay of the architecture LR over the full run.
        self.arch_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.arch_optimizer,
            float(args.epochs),
            eta_min=args.arch_learning_rate_min)

        # Latency objective is enabled by a positive target latency.
        self.latency_cost = args.target_latency > 0.
        self.target_latency = args.target_latency
        if self.latency_cost or self.args.meta_loss == 'relax':
            assert args.meta_loss in {
                'relax', 'rebar', 'reinforce'
            }, 'this is only implemented for rebar and reinforce'
            # Surrogate predicts latency from the flattened alpha vector
            # (normal + reduce cell parameters).
            normal_size, reduce_size = self.alpha.module.alphas_size()
            alpha_size = normal_size + reduce_size
            self.surrogate = SurrogateLinear(alpha_size, self.logging).cuda()
            self.latency_pred_loss = utils.AverageMeter()
            self.latency_value = utils.AverageMeter()
            self.latency_coeff = args.latency_coeff
            self.latency_coeff_curr = None  # set later, presumably annealed
            self.num_repeat = 10
            self.latency_batch_size = 24
            assert self.latency_batch_size <= args.batch_size
            self.num_arch_samples = 10000
            # print('***************** change the number of samples *******')
            # self.num_arch_samples = 200
            self.surrogate_not_train = True  # surrogate not yet fitted

            self.latency_actual = []
            self.latency_estimate = []

        # Extra layers, if any.
        self.meta_loss = args.meta_loss

        # weights generalization error
        self.gen_error_alpha = args.gen_error_alpha
        self.gen_error_alpha_lambda = args.gen_error_alpha_lambda

        # Get the meta learning criterion.
        # reduction='none' keeps per-sample losses (needed by REINFORCE-style
        # estimators that weight individual samples).
        if self.meta_loss in ['default', 'rebar', 'reinforce']:
            self.criterion = nn.CrossEntropyLoss(reduction='none')
            self.criterion = self.criterion.cuda()

        if self.meta_loss == 'reinforce':
            # Moving-average baselines for variance reduction.
            self.exp_avg1 = utils.ExpMovingAvgrageMeter()
            self.exp_avg2 = utils.ExpMovingAvgrageMeter()

        self.alpha_loss = args.alpha_loss

        # Housekeeping. These are (re)initialized by reset_counter().
        self.loss = None
        self.accuracy = None
        self.count = None
        self.loss_diff_sign = None
        self.reset_counter()
        self.report_freq = args.report_freq
        self.writer = writer