Example #1
def evaluate(dataloader, model, topk=(1,)):
    """

    :param dataloader:
    :param model:
    :param topk: [tuple]          output the top topk accuracy
    :return:     [list[float]]    topk accuracy
    """
    model.eval()
    test_accuracy = AverageMeter()
    test_accuracy.reset()

    with torch.no_grad():
        for x, y, _ in dataloader:
            x = x.cuda()
            y = y.cuda()
            logits = model(x)

            acc = accuracy(logits, y, topk)
            test_accuracy.update(acc[0], x.size(0))

    return test_accuracy.avg
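
The snippet assumes two helpers that are not shown, AverageMeter and accuracy. A minimal sketch of both in the widely used PyTorch ImageNet-example style (the project's actual versions may differ in signature and return type):

class AverageMeter:
    """Tracks the running average of a scalar metric."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(logits, target, topk=(1,)):
    """Top-k accuracy (in percent) for each k in topk."""
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                        # maxk x batch
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # maxk x batch
    batch = target.size(0)
    return [correct[:k].reshape(-1).float().sum().item() * 100.0 / batch
            for k in topk]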
Example #2
def main(seed=25):
    seed_everything(seed)
    device = torch.device('cuda:0')

    # arguments
    args = Args().parse()
    n_class = args.n_class

    img_path_train = args.img_path_train
    mask_path_train = args.mask_path_train
    img_path_val = args.img_path_val
    mask_path_val = args.mask_path_val

    model_path = os.path.join(args.model_path, args.task_name)  # save model
    log_path = args.log_path
    output_path = args.output_path

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    task_name = args.task_name
    print(task_name)
    ###################################
    evaluation = args.evaluation
    test = evaluation and False  # test branch is currently disabled
    print("evaluation:", evaluation, "test:", test)

    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = args.batch_size
    num_workers = args.num_workers
    config = args.config

    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')

    dataset_train = DoiDataset(img_path_train,
                               config,
                               train=True,
                               root_mask=mask_path_train)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    dataset_val = DoiDataset(img_path_val,
                             config,
                             train=True,
                             root_mask=mask_path_val)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

    ###################################
    print("creating models......")
    model = DoiNet(n_class, config['min_descriptor'] + 6, 4)
    model = create_model_load_weights(model,
                                      evaluation=False,
                                      ckpt_path=args.ckpt_path)
    model.to(device)

    ###################################
    num_epochs = args.epochs
    learning_rate = args.lr

    optimizer = get_optimizer(model, learning_rate=learning_rate)
    scheduler = LR_Scheduler(args.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    criterion_node = nn.CrossEntropyLoss()
    criterion_edge = nn.BCELoss()
    alpha = args.alpha

    writer = SummaryWriter(log_dir=os.path.join(log_path, task_name))
    f_log = open(os.path.join(log_path, task_name + ".log"), 'w')
    #######################################
    trainer = Trainer(criterion_node,
                      criterion_edge,
                      optimizer,
                      n_class,
                      device,
                      alpha=alpha)
    evaluator = Evaluator(n_class, device)

    best_pred = 0.0
    print("start training......")
    log = task_name + '\n'
    for k, v in args.__dict__.items():
        log += str(k) + ' = ' + str(v) + '\n'
    print(log)
    f_log.write(log)
    f_log.flush()

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0
        train_loss_edge = 0
        train_loss_node = 0

        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)

            if evaluation:  # evaluation mode: skip training
                break
            scheduler(optimizer, i_batch, epoch, best_pred)
            loss, loss_node, loss_edge = trainer.train(sample, model)
            train_loss += loss.item()
            train_loss_node += loss_node.item()
            train_loss_edge += loss_edge.item()
            train_scores_node, train_scores_edge = trainer.get_scores()

            batch_time.update(time.time() - start_time)
            start_time = time.time()

            if i_batch % 2 == 0:
                tbar.set_description(
                    'Train loss: %.4f (loss_node=%.4f  loss_edge=%.4f); F1 node: %.4f  F1 edge: %.4f; data time: %.2f; batch time: %.2f'
                    % (train_loss / (i_batch + 1), train_loss_node /
                       (i_batch + 1), train_loss_edge /
                       (i_batch + 1), train_scores_node["macro_f1"],
                       train_scores_edge["macro_f1"], data_time.avg,
                       batch_time.avg))

        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()

        if epoch % 1 == 0:  # validate every epoch
            with torch.no_grad():
                model.eval()
                print("evaluating...")

                tbar = tqdm(dataloader_val)
                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    pred_node, pred_edge = evaluator.eval(sample, model)
                    val_scores_node, val_scores_edge = evaluator.get_scores()

                    batch_time.update(time.time() - start_time)
                    tbar.set_description(
                        'F1 node: %.4f  F1 edge: %.4f; data time: %.2f; batch time: %.2f'
                        % (val_scores_node["macro_f1"],
                           val_scores_edge["macro_f1"], data_time.avg,
                           batch_time.avg))
                    start_time = time.time()

            data_time.reset()
            batch_time.reset()
            val_scores_node, val_scores_edge = evaluator.get_scores()
            evaluator.reset_metrics()

            best_pred = save_model(model, model_path, val_scores_node,
                                   val_scores_edge, alpha, task_name, epoch,
                                   best_pred)
            write_log(f_log, train_scores_node, train_scores_edge,
                      val_scores_node, val_scores_edge, epoch, num_epochs)
            write_summaryWriter(writer, train_loss / len(dataloader_train),
                                optimizer, train_scores_node,
                                train_scores_edge, val_scores_node,
                                val_scores_edge, epoch)

    f_log.close()
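
seed_everything is not defined in the example; a plausible implementation, assuming the usual intent of fixing every RNG the pipeline touches (the project's version may differ):

import os
import random

import numpy as np
import torch


def seed_everything(seed=25):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True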
Example #3
def main():
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError):  # missing or malformed iter file
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    print('===========training options==========')
    # print(opt)

    ### define the model for training
    model = P3dModel()
    model.initialize(opt)

    ### training: continue from a checkpoint, or finetune from a pretrained model
    if opt.continue_train:
        model.load(fineture=False, which_epoch=start_epoch - 1, pretrain='')
    else:
        if opt.modality == 'RGB':
            pretrained_file = 'p3d_rgb_199.checkpoint.pth.tar'
        elif opt.modality == 'Flow':
            pretrained_file = 'p3d_flow_199.checkpoint.pth.tar'
        model.load(fineture=True, pretrain=pretrained_file)

    print('%s is in use' % model.name())

    #def vis
    Visual = Visualizer(opt)
    ##
    dummy_input = torch.rand(1, 3, 16, 224, 224).cuda()
    Visual.tbx_write_net(model.model, dummy_input)

    ### def all data loader
    train_data_loader, val_data_loader, dataset_size, _ = CreateDataLoader(
        opt, model)

    print('#training images = %d' % dataset_size)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    cudnn.benchmark = True

    #def metrics
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    print_delta = total_steps % opt.print_freq
    update_size = opt.larger_batch_size // opt.batch_size

    update_num = 0
    model.module.optimizer.zero_grad()
    for epoch in range(start_epoch, opt.epochs + 1):
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size

        top1.reset()
        top5.reset()
        losses.reset()
        model.train()
        for i, data in enumerate(train_data_loader, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            # print(data['data'].shape)
            # print(data['label'].shape)

            input = Variable(data['data'].cuda())
            label = Variable(data['label'].cuda())
            ############## Forward Pass ######################
            pred, loss = model(
                input,
                label,
            )

            # need mean or not?
            # loss=torch.mean(loss)

            pt1, pt5, _ = accuracy(pred.data,
                                   data['label'].cuda(),
                                   topk=(1, 5))
            top1.update(pt1.item(), input.size(0))
            top5.update(pt5.item(), input.size(0))
            losses.update(loss.item(), input.size(0))
            ############### Backward Pass ####################
            # update model weights

            loss.backward()
            update_num += 1
            if update_num == update_size:

                model.module.optimizer.step()
                model.module.optimizer.zero_grad()
                update_num = 0

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    'train_loss': losses.get(),
                    'top1': top1.get(),
                    'top5': top5.get()
                }
                t = (time.time() - iter_start_time) / opt.batch_size
                Visual.print_current_errors(epoch, epoch_iter, errors, t)
                Visual.tbx_write_errors(errors, total_steps, 'Train/loss')
                top1.reset()
                top5.reset()
                losses.reset()

            if epoch_iter >= dataset_size:
                break

        ### save model for this epoch
        ##valid here
        top1.reset()
        top5.reset()
        losses.reset()
        v__start_time = time.time()
        this_iter = 0
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(val_data_loader):

                input = Variable(data['data'].cuda())
                label = Variable(data['label'].cuda())

                pred, loss = model(
                    input,
                    label,
                )

                pt1, pt5, _ = accuracy(pred.data,
                                       data['label'].cuda(),
                                       topk=(1, 5))
                top1.update(pt1.item(), input.size(0))
                top5.update(pt5.item(), input.size(0))
                losses.update(loss.item(), input.size(0))
                this_iter = this_iter + opt.batch_size

            errors = {
                'valid_loss': losses.get(),
                'top1': top1.get(),
                'top5': top5.get()
            }
            t = (time.time() - v__start_time) / opt.batch_size
            Visual.print_current_errors(epoch, this_iter, errors, t)
            Visual.tbx_write_errors(errors, total_steps, 'Valid/loss')
            top1.reset()
            top5.reset()
            losses.reset()

        if epoch % opt.save_epoch_freq == 0:

            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch % opt.lr_decay_epoch == 0:
            model.module.update_learning_rate()
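
The inner loop above implements gradient accumulation: loss.backward() runs on every mini-batch, but the optimizer only steps once per update_size = larger_batch_size // batch_size batches, emulating a larger effective batch under a memory budget. The pattern in isolation, as a sketch (loss_fn, model and loader are placeholders):

def train_with_accumulation(model, loader, optimizer, loss_fn, update_size):
    """One epoch with gradient accumulation: step once per update_size batches."""
    update_num = 0
    optimizer.zero_grad()
    for x, y in loader:
        loss = loss_fn(model(x), y)
        loss.backward()              # gradients accumulate in param.grad
        update_num += 1
        if update_num == update_size:
            optimizer.step()         # apply the summed gradients
            optimizer.zero_grad()
            update_num = 0

Note that, as in the example, the loss is not divided by update_size, so the accumulated gradient is the sum over the virtual batch rather than its mean.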
Example #4
def main():
    opt = TestOptions().parse(False)
    if opt.modality == 'RGB':
        channel = 3
        data_length = 16
    else:
        raise NotImplementedError('only RGB modality is handled here')
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)

    ### define the model for testing
    model = P3dModel()
    model.initialize(opt)
    num_classes = model.num_classes
    which_epoch = start_epoch - 1
    if opt.epoch_num > 0:
        which_epoch = opt.epoch_num
    model.load(fineture=False, which_epoch=which_epoch, pretrain='')

    print('%s is in use' % model.name())

    ### def all data loader
    test_data_loader, dataset_size = CreateTestLoader(opt, model)

    print('#testing images = %d' % dataset_size)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    cudnn.benchmark = True

    #def metrics
    top1 = AverageMeter()
    top5 = AverageMeter()
    class_top1 = [AverageMeter() for i in range(num_classes)]
    class_top5 = [AverageMeter() for i in range(num_classes)]
    mix_m = np.zeros((num_classes, num_classes), np.float32)
    top1.reset()
    top5.reset()

    for i in range(num_classes):
        class_top1[i].reset()
        class_top5[i].reset()

    v__start_time = time.time()
    this_iter = 0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_data_loader):
            print('processing video %d' % i)
            input = Variable(data['data'].cuda())
            label = Variable(data['label'].cuda())

            sizes = input.size()
            assert sizes[2] == opt.crop_num * opt.num_segments * data_length, \
                'shape error'

            input = input.view(sizes[0], sizes[1],
                               opt.crop_num * opt.num_segments, data_length,
                               sizes[3], sizes[4])
            input = input.permute(0, 2, 1, 3, 4, 5).contiguous()
            input = input.view(sizes[0] * opt.crop_num * opt.num_segments,
                               sizes[1], data_length, sizes[3], sizes[4])
            pred = model.module.inference(input)
            pred = pred.view(opt.batch_size, opt.crop_num * opt.num_segments,
                             -1)
            new_pred = torch.sum(pred.data, 1, False)

            pt1, pt5, acc_v = accuracy(new_pred,
                                       data['label'].cuda(),
                                       topk=(1, 5))
            top1.update(pt1.item(), opt.batch_size)
            top5.update(pt5.item(), opt.batch_size)

            d = data['label'].numpy()
            assert d.shape[0] == 1, 'only support batch size ==1'
            key = d[0]
            class_top1[key].update(pt1.item(), opt.batch_size)
            class_top5[key].update(pt5.item(), opt.batch_size)

            cc = acc_v[0]
            mix_m[key, cc] += 1
            this_iter = this_iter + opt.batch_size

            #if i>100:
            #   break

        t = (time.time() - v__start_time) / opt.batch_size
        print('%f times test result top5: %f, top1: %f' % (t, top5.get(),
                                                           top1.get()))

        total_loss_file = '%s/total_loss.txt' % (opt.checkpoints_dir)
        message = '%s_nseg_%d_ncrop_%d_nepoch_%d : ' % (
            opt.name, opt.num_segments, opt.crop_num, opt.epoch_num)

        with open(total_loss_file, "a") as tlf:
            message += '%f times test result top5: %f, top1: %f \n' % (
                t, top5.get(), top1.get())
            tlf.write('%s ' % message)

        print('==============each class accuracy=========================')
        for i in range(num_classes):
            print('class %d test result top5: %f, top1: %f' % (
                i, class_top5[i].get(), class_top1[i].get()))
            mix_m[i, :] /= max(class_top1[i].count, 1)  # guard unseen classes

        plot_mix(mix_m, opt)

        top1.reset()
        top5.reset()
        for i in range(num_classes):
            class_top1[i].reset()
            class_top5[i].reset()
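
The reshaping before inference splits a batch of stacked test crops/segments into independent clips and then sums the per-clip scores into one video-level prediction. A standalone illustration of that view/permute/view chain with made-up shapes (all sizes and the class count are placeholders):

import torch

batch, channels, crops, segments, length, h, w = 1, 3, 2, 5, 16, 160, 160
x = torch.rand(batch, channels, crops * segments * length, h, w)

x = x.view(batch, channels, crops * segments, length, h, w)
x = x.permute(0, 2, 1, 3, 4, 5).contiguous()        # clips become batch entries
x = x.view(batch * crops * segments, channels, length, h, w)

logits = torch.rand(batch * crops * segments, 101)  # stand-in for model output
video_score = logits.view(batch, crops * segments, -1).sum(dim=1)
print(video_score.shape)                            # torch.Size([1, 101])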
Example #5
class Model:
    def __init__(self, args):
        # common args
        self.args = args
        self.best_miou = -1.0
        self.dataset_name = args.dataset_name
        self.debug = args.debug
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu:0")
        self.dir_checkpoints = f"{args.dir_root}/checkpoints/{args.experim_name}"
        self.experim_name = args.experim_name
        self.ignore_index = args.ignore_index
        self.init_n_pixels = args.n_init_pixels
        self.max_budget = args.max_budget
        self.n_classes = args.n_classes
        self.n_epochs = args.n_epochs
        self.n_pixels_by_us = args.n_pixels_by_us
        self.network_name = args.network_name
        self.nth_query = -1
        self.stride_total = args.stride_total

        self.dataloader = get_dataloader(deepcopy(args),
                                         val=False,
                                         query=False,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         n_workers=args.n_workers)
        self.dataloader_query = get_dataloader(deepcopy(args),
                                               val=False,
                                               query=True,
                                               shuffle=False,
                                               batch_size=1,
                                               n_workers=args.n_workers)
        self.dataloader_val = get_dataloader(deepcopy(args),
                                             val=True,
                                             query=False,
                                             shuffle=False,
                                             batch_size=1,
                                             n_workers=args.n_workers)

        self.model = get_model(args).to(self.device)

        self.lr_scheduler_type = args.lr_scheduler_type
        self.query_selector = QuerySelector(args, self.dataloader_query)
        self.vis = Visualiser(args.dataset_name)
        # for tracking stats
        self.running_loss, self.running_score = AverageMeter(), RunningScore(
            args.n_classes)

        # if active learning
        # if self.n_pixels_by_us > 0:
        #     self.model_0_query = f"{self.dir_checkpoints}/0_query_{args.seed}.pt"

    def __call__(self):
        # fully-supervised model
        if self.n_pixels_by_us == 0:
            dir_checkpoints = f"{self.dir_checkpoints}/fully_sup"
            os.makedirs(f"{dir_checkpoints}", exist_ok=True)

            self.log_train, self.log_val = f"{dir_checkpoints}/log_train.txt", f"{dir_checkpoints}/log_val.txt"
            write_log(f"{self.log_train}",
                      header=["epoch", "mIoU", "pixel_acc", "loss"])
            write_log(f"{self.log_val}", header=["epoch", "mIoU", "pixel_acc"])

            self._train()

        # active learning model
        else:
            n_stages = self.max_budget // self.n_pixels_by_us
            n_stages += 1 if self.init_n_pixels > 0 else 0
            print("n_stages:", n_stages)
            for nth_query in range(n_stages):
                dir_checkpoints = f"{self.dir_checkpoints}/{nth_query}_query"
                os.makedirs(f"{dir_checkpoints}", exist_ok=True)

                self.log_train, self.log_val = f"{dir_checkpoints}/log_train.txt", f"{dir_checkpoints}/log_val.txt"
                write_log(f"{self.log_train}",
                          header=["epoch", "mIoU", "pixel_acc", "loss"])
                write_log(f"{self.log_val}",
                          header=["epoch", "mIoU", "pixel_acc"])

                self.nth_query = nth_query

                model = self._train()

                # select queries using the current model and label them.
                queries = self.query_selector(nth_query, model)
                self.dataloader.dataset.label_queries(queries, nth_query + 1)

                if nth_query == n_stages - 1:
                    break

                # if nth_query == 0:
                #     torch.save({"model": model.state_dict()}, self.model_0_query)
        return

    def _train_epoch(self, epoch, model, optimizer, lr_scheduler):
        if self.n_pixels_by_us != 0:
            print(
                f"training an epoch {epoch} of {self.nth_query}th query ({self.dataloader.dataset.n_pixels_total} labelled pixels)"
            )
            fp = f"{self.dir_checkpoints}/{self.nth_query}_query/{epoch}_train.png"
        else:
            fp = f"{self.dir_checkpoints}/fully_sup/{epoch}_train.png"
        log = f"{self.log_train}"

        dataloader_iter, tbar = iter(self.dataloader), tqdm(
            range(len(self.dataloader)))
        model.train()
        for _ in tbar:
            dict_data = next(dataloader_iter)
            x, y = dict_data['x'].to(self.device), dict_data['y'].to(
                self.device)

            # if queries
            if self.n_pixels_by_us != 0:
                mask = dict_data['queries'].to(self.device, torch.bool)
                y.flatten()[~mask.flatten()] = self.ignore_index

            # forward pass
            dict_outputs = model(x)

            logits = dict_outputs["pred"]
            dict_losses = {
                "ce": F.cross_entropy(logits,
                                      y,
                                      ignore_index=self.ignore_index)
            }

            # backward pass
            loss = sum(dict_losses.values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prob, pred = F.softmax(logits.detach(),
                                   dim=1), logits.argmax(dim=1)
            self.running_score.update(y.cpu().numpy(), pred.cpu().numpy())
            self.running_loss.update(loss.detach().item())

            scores = self.running_score.get_scores()[0]
            miou, pixel_acc = scores['Mean IoU'], scores['Pixel Acc']

            # description
            description = f"({self.experim_name}) Epoch {epoch} | mIoU.: {miou:.3f} | pixel acc.: {pixel_acc:.3f} | " \
                          f"avg loss: {self.running_loss.avg:.3f}"
            for loss_k, loss_v in dict_losses.items():
                description += f" | {loss_k}: {loss_v.detach().cpu().item():.3f}"
            tbar.set_description(description)

            if self.lr_scheduler_type == "Poly":
                lr_scheduler.step(epoch=epoch - 1)

            if self.debug:
                break

        if self.lr_scheduler_type == "MultiStepLR":
            lr_scheduler.step(epoch=epoch - 1)

        write_log(
            log, list_entities=[epoch, miou, pixel_acc, self.running_loss.avg])
        self._reset_meters()

        ent, lc, ms = [
            self._query(prob, uc)[0].cpu()
            for uc in ["entropy", "least_confidence", "margin_sampling"]
        ]
        dict_tensors = {
            'input': dict_data['x'][0].cpu(),
            'target': dict_data['y'][0].cpu(),
            'pred': pred[0].detach().cpu(),
            'confidence': lc,
            'margin': -ms,  # minus sign draws the smaller-margin part brighter
            'entropy': ent
        }

        self.vis(dict_tensors, fp=fp)
        return model, optimizer, lr_scheduler

    def _train(self):
        print(f"\n({self.experim_name}) training...\n")
        model = get_model(self.args).to(self.device)
        optimizer = get_optimizer(self.args, model)
        lr_scheduler = get_lr_scheduler(self.args,
                                        optimizer=optimizer,
                                        iters_per_epoch=len(self.dataloader))

        for e in range(1, 1 + self.n_epochs):
            model, optimizer, lr_scheduler = self._train_epoch(
                e, model, optimizer, lr_scheduler)
            self._val(e, model)

            if self.debug:
                break

        self.best_miou = -1.0
        return model

    @torch.no_grad()
    def _val(self, epoch, model):
        dataloader_iter, tbar = iter(self.dataloader_val), tqdm(
            range(len(self.dataloader_val)))
        model.eval()
        for _ in tbar:
            dict_data = next(dataloader_iter)
            x, y = dict_data['x'].to(self.device), dict_data['y'].to(
                self.device)

            if self.dataset_name == "voc":
                h, w = y.shape[1:]
                pad_h = ceil(
                    h / self.stride_total) * self.stride_total - x.shape[2]
                pad_w = ceil(
                    w / self.stride_total) * self.stride_total - x.shape[3]
                x = F.pad(x, pad=(0, pad_w, 0, pad_h), mode='reflect')
                dict_outputs = model(x)
                dict_outputs['pred'] = dict_outputs['pred'][:, :, :h, :w]

            else:
                dict_outputs = model(x)

            logits = dict_outputs['pred']
            prob, pred = F.softmax(logits.detach(),
                                   dim=1), logits.argmax(dim=1)

            self.running_score.update(y.cpu().numpy(), pred.cpu().numpy())
            scores = self.running_score.get_scores()[0]
            miou, pixel_acc = scores['Mean IoU'], scores['Pixel Acc']
            tbar.set_description(
                f"mIoU: {miou:.3f} | pixel acc.: {pixel_acc:.3f}")

            if self.debug:
                break

        if miou > self.best_miou:
            state_dict = {"model": model.state_dict()}

            if self.n_pixels_by_us != 0:
                torch.save(
                    state_dict,
                    f"{self.dir_checkpoints}/{self.nth_query}_query/best_miou_model.pt"
                )
            else:
                torch.save(
                    state_dict,
                    f"{self.dir_checkpoints}/fully_sup/best_miou_model.pt")
            print(
                f"best model has been saved"
                f"(epoch: {epoch} | prev. miou: {self.best_miou:.4f} => new miou: {miou:.4f})."
            )
            self.best_miou = miou

        write_log(self.log_val, list_entities=[epoch, miou, pixel_acc])

        print(
            f"\n{'=' * 100}"
            f"\nExperim name: {self.experim_name}"
            f"\nEpoch {epoch} | miou: {miou:.3f} | pixel_acc.: {pixel_acc:.3f}"
            f"\n{'=' * 100}\n")

        self._reset_meters()

        ent, lc, ms = [
            self._query(prob, uc)[0].cpu()
            for uc in ["entropy", "least_confidence", "margin_sampling"]
        ]
        dict_tensors = {
            'input': dict_data['x'][0].cpu(),
            'target': dict_data['y'][0].cpu(),
            'pred': pred[0].detach().cpu(),
            'confidence': lc,
            'margin': -ms,  # minus sign draws the smaller-margin part brighter
            'entropy': ent
        }

        if self.n_pixels_by_us != 0:
            self.vis(
                dict_tensors,
                fp=f"{self.dir_checkpoints}/{self.nth_query}_query/{epoch}_val.png"
            )
        else:
            self.vis(dict_tensors,
                     fp=f"{self.dir_checkpoints}/fully_sup/{epoch}_val.png")
        return

    @staticmethod
    def _query(prob, query_strategy):
        # prob: b x n_classes x h x w
        if query_strategy == "least_confidence":
            query = 1.0 - prob.max(dim=1)[0]  # b x h x w

        elif query_strategy == "margin_sampling":
            query = prob.topk(k=2, dim=1).values  # b x k x h x w
            query = (query[:, 0, :, :] - query[:, 1, :, :]).abs()  # b x h x w

        elif query_strategy == "entropy":
            query = (-prob * torch.log(prob)).sum(dim=1)  # b x h x w

        elif query_strategy == "random":
            b, _, h, w = prob.shape
            query = torch.rand((b, h, w))

        else:
            raise ValueError(f"unknown query strategy: {query_strategy}")
        return query

    def _reset_meters(self):
        self.running_loss.reset()
        self.running_score.reset()
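
To make the shapes and directions of the _query strategies concrete, a toy run on random predictions (entropy and least-confidence are high where the model is uncertain, margin is low there, which is why the visualisation code above negates it):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 8, 8)                # b x n_classes x h x w
prob = F.softmax(logits, dim=1)                 # strictly positive, so log is safe

entropy = (-prob * torch.log(prob)).sum(dim=1)  # b x h x w
least_conf = 1.0 - prob.max(dim=1)[0]           # b x h x w
top2 = prob.topk(k=2, dim=1).values             # b x 2 x h x w
margin = (top2[:, 0] - top2[:, 1]).abs()        # b x h x w

print(entropy.shape, least_conf.shape, margin.shape)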
Example #6
class CoteachingTrainer(object):
    def __init__(self, config):
        # Config
        self._config = config
        self._epochs = config['epochs']
        self._step = config['step']
        self._logfile = config['log']
        self._n_classes = config['n_classes']

        # Network
        if ',' in config['net']:
            net_name_1, net_name_2 = config['net'].split(',')
        else:
            net_name_1, net_name_2 = config['net'], config['net']
        Net1, _ = make_network(net_name_1)
        Net2, _ = make_network(net_name_2)
        if self._step == 0:
            net1 = Net1(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=False,
                        fc_init='He')
            net2 = Net2(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=False,
                        fc_init='Xavier')
        elif self._step == 1:
            net1 = Net1(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=True)
            net2 = Net2(n_classes=self._n_classes,
                        pretrained=True,
                        use_two_step=True)
        elif self._step == 2:
            net1 = Net1(n_classes=self._n_classes,
                        pretrained=False,
                        use_two_step=True)
            net2 = Net2(n_classes=self._n_classes,
                        pretrained=False,
                        use_two_step=True)
        else:
            raise AssertionError('step can only be 0, 1, 2')
        # Move network to cuda
        print('| Number of available GPUs : {} ({})'.format(
            torch.cuda.device_count(), os.environ["CUDA_VISIBLE_DEVICES"]))
        if torch.cuda.device_count() >= 1:
            self._net1 = nn.DataParallel(net1).cuda()
            self._net2 = nn.DataParallel(net2).cuda()
        else:
            raise AssertionError('CPU version is not implemented yet!')

        # Loss Criterion
        self.T_k = config['warmup_epochs']
        if self._step == 1:
            self.T_k = self._epochs

        # Optimizer
        if self._step == 1:
            params_to_optimize1 = self._net1.module.fc.parameters()
            params_to_optimize2 = self._net2.module.fc.parameters()
        else:
            params_to_optimize1 = self._net1.parameters()
            params_to_optimize2 = self._net2.parameters()
        self._optimizer1 = make_optimizer(params_to_optimize1,
                                          lr=config['lr'] / 2,
                                          weight_decay=config['weight_decay'],
                                          opt='SGD')
        self._optimizer2 = make_optimizer(params_to_optimize2,
                                          lr=config['lr'],
                                          weight_decay=config['weight_decay'],
                                          opt='SGD')

        self._scheduler1 = optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer1, T_max=self._epochs, eta_min=0)
        self._scheduler2 = optim.lr_scheduler.CosineAnnealingLR(
            self._optimizer2, T_max=self._epochs, eta_min=0)

        # metrics
        self._train_loss1 = AverageMeter()
        self._train_loss2 = AverageMeter()
        self._train_accuracy1 = AverageMeter()
        self._train_accuracy2 = AverageMeter()
        self._epoch_train_time = AverageMeter()

        # Dataloader
        train_transform = make_transform(phase='train', output_size=448)
        test_transform = make_transform(phase='test', output_size=448)
        train_data = IndexedImageFolder(os.path.join(config['data_base'],
                                                     'train'),
                                        transform=train_transform)
        test_data = IndexedImageFolder(os.path.join(config['data_base'],
                                                    'val'),
                                       transform=test_transform)
        self._train_loader = data.DataLoader(train_data,
                                             batch_size=config['batch_size'],
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)
        self._test_loader = data.DataLoader(test_data,
                                            batch_size=16,
                                            shuffle=False,
                                            num_workers=4,
                                            pin_memory=True)
        print('|-----------------------------------------------------')
        print('| Number of samples in train set : {}'.format(len(train_data)))
        print('| Number of samples in test  set : {}'.format(len(test_data)))
        print('| Number of classes in train set : {}'.format(
            len(train_data.classes)))
        print('| Number of classes in test  set : {}'.format(
            len(test_data.classes)))
        print('|-----------------------------------------------------')
        assert len(train_data.classes) == self._n_classes and \
            len(test_data.classes) == self._n_classes, 'number of classes is wrong'

        # Resume or not
        if config['resume']:
            assert os.path.isfile(
                'checkpoint.pth'), 'no checkpoint.pth exists!'
            print('---> loading checkpoint.pth <---')
            checkpoint = torch.load('checkpoint.pth')
            assert self._step == checkpoint[
                'step'], 'step in checkpoint does not match step in argument'
            self._start_epoch = checkpoint['epoch']
            self._best_accuracy1 = checkpoint['best_accuracy1']
            self._best_accuracy2 = checkpoint['best_accuracy2']
            self._best_epoch1 = checkpoint['best_epoch1']
            self._best_epoch2 = checkpoint['best_epoch2']
            self._net1.load_state_dict(checkpoint['state_dict1'])
            self._net2.load_state_dict(checkpoint['state_dict2'])
            self._optimizer1.load_state_dict(checkpoint['optimizer1'])
            self._optimizer2.load_state_dict(checkpoint['optimizer2'])
            self._scheduler1.load_state_dict(checkpoint['scheduler1'])
            self._scheduler2.load_state_dict(checkpoint['scheduler2'])
            self.memory_pool1 = checkpoint['memory_pool1']
            self.memory_pool2 = checkpoint['memory_pool2']
        else:
            print('---> no checkpoint loaded <---')
            if self._step == 2:
                print('---> loading step1_best_epoch.pth <---')
                assert os.path.isfile('model/net1_step1_best_epoch.pth') and \
                    os.path.isfile('model/net2_step1_best_epoch.pth')
                self._net1.load_state_dict(
                    torch.load('model/net1_step1_best_epoch.pth'))
                self._net2.load_state_dict(
                    torch.load('model/net2_step1_best_epoch.pth'))
            self._start_epoch = 0
            self._best_accuracy1 = 0.0
            self._best_accuracy2 = 0.0
            self._best_epoch1 = None
            self._best_epoch2 = None
            self.memory_pool1 = Queue(n_samples=len(train_data),
                                      memory_length=config['memory_length'])
            self.memory_pool2 = Queue(n_samples=len(train_data),
                                      memory_length=config['memory_length'])
        self._scheduler1.last_epoch = self._start_epoch
        self._scheduler2.last_epoch = self._start_epoch

    def train(self):
        console_header = 'Epoch\tTrain_Loss1\tTrain_Loss2\tTrain_Accuracy1\tTrain_Accuracy2\t' \
                         'Test_Accuracy1\tTest_Accuracy2\tEpoch_Runtime\tLearning_Rate1\tLearning_Rate2'
        print_to_console(console_header)
        print_to_logfile(self._logfile, console_header, init=True)

        for t in range(self._start_epoch, self._epochs):
            epoch_start = time.time()
            self._scheduler1.step(epoch=t)
            self._scheduler2.step(epoch=t)
            # reset average meters
            self._train_loss1.reset()
            self._train_loss2.reset()
            self._train_accuracy1.reset()
            self._train_accuracy2.reset()

            self._net1.train(True)
            self._net2.train(True)
            self.single_epoch_training(t)
            test_accuracy1 = evaluate(self._test_loader, self._net1)
            test_accuracy2 = evaluate(self._test_loader, self._net2)

            lr1 = get_lr_from_optimizer(self._optimizer1)
            lr2 = get_lr_from_optimizer(self._optimizer2)

            if test_accuracy1 > self._best_accuracy1:
                self._best_accuracy1 = test_accuracy1
                self._best_epoch1 = t + 1
                torch.save(
                    self._net1.state_dict(),
                    'model/net1_step{}_best_epoch.pth'.format(self._step))
            if test_accuracy2 > self._best_accuracy2:
                self._best_accuracy2 = test_accuracy2
                self._best_epoch2 = t + 1
                torch.save(
                    self._net2.state_dict(),
                    'model/net2_step{}_best_epoch.pth'.format(self._step))

            epoch_end = time.time()
            single_epoch_runtime = epoch_end - epoch_start
            # Logging
            console_content = '{:05d}\t{:10.4f}\t{:10.4f}\t{:14.4f}\t{:14.4f}\t' \
                              '{:13.4f}\t{:13.4f}\t{:13.2f}\t' \
                              '{:13.1e}\t{:13.1e}'.format(t + 1, self._train_loss1.avg, self._train_loss2.avg,
                                                          self._train_accuracy1.avg, self._train_accuracy2.avg,
                                                          test_accuracy1, test_accuracy2,
                                                          single_epoch_runtime, lr1, lr2)
            print_to_console(console_content)
            print_to_logfile(self._logfile, console_content, init=False)

            # save checkpoint
            save_checkpoint({
                'epoch': t + 1,
                'state_dict1': self._net1.state_dict(),
                'state_dict2': self._net2.state_dict(),
                'best_epoch1': self._best_epoch1,
                'best_epoch2': self._best_epoch2,
                'best_accuracy1': self._best_accuracy1,
                'best_accuracy2': self._best_accuracy2,
                'optimizer1': self._optimizer1.state_dict(),
                'optimizer2': self._optimizer2.state_dict(),
                'step': self._step,
                'scheduler1': self._scheduler1.state_dict(),
                'scheduler2': self._scheduler2.state_dict(),
                'memory_pool1': self.memory_pool1,
                'memory_pool2': self.memory_pool2,
            })

        console_content = 'Net1: Best at epoch {}, test accuracy is {}'.format(
            self._best_epoch1, self._best_accuracy1)
        print_to_console(console_content)
        console_content = 'Net2: Best at epoch {}, test accuracy is {}'.format(
            self._best_epoch2, self._best_accuracy2)
        print_to_console(console_content)

        # rename log file
        os.rename(
            self._logfile,
            self._logfile.replace(
                '.txt', '-{}_{}_{}_{:.4f}_{:.4f}.txt'.format(
                    self._config['net'], self._config['batch_size'],
                    self._config['lr'], self._best_accuracy1,
                    self._best_accuracy2)))

    def single_epoch_training(self, epoch, log_iter=True, log_freq=200):
        if epoch >= self.T_k:
            stats_log_path1 = 'stats/net1_drop_n_reuse_stats_epoch{:03d}.csv'.format(
                epoch + 1)
            stats_log_path2 = 'stats/net2_drop_n_reuse_stats_epoch{:03d}.csv'.format(
                epoch + 1)
            stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
            print_to_logfile(stats_log_path1,
                             stats_log_header,
                             init=True,
                             end='\n')
            print_to_logfile(stats_log_path2,
                             stats_log_header,
                             init=True,
                             end='\n')

        for it, (x, y, indices) in enumerate(self._train_loader):
            s = time.time()

            x = x.cuda()
            y = y.cuda()
            self._optimizer1.zero_grad()
            self._optimizer2.zero_grad()
            logits1 = self._net1(x)
            logits2 = self._net2(x)
            losses1, ce_loss1, losses2, ce_loss2 = \
                cot_std_loss(logits1, logits2, y, indices, self.T_k, epoch,
                             self.memory_pool1, self.memory_pool2, eps=self._config['eps'])
            loss1 = losses1.mean()
            loss2 = losses2.mean()

            self.memory_pool1.update(indices=indices,
                                     losses=ce_loss1.detach().data.cpu(),
                                     scores=F.softmax(
                                         logits1, dim=1).detach().data.cpu(),
                                     labels=y.detach().data.cpu())
            self.memory_pool2.update(indices=indices,
                                     losses=ce_loss2.detach().data.cpu(),
                                     scores=F.softmax(
                                         logits2, dim=1).detach().data.cpu(),
                                     labels=y.detach().data.cpu())

            train_accuracy1 = accuracy(logits1, y, topk=(1, ))
            train_accuracy2 = accuracy(logits2, y, topk=(1, ))

            self._train_loss1.update(loss1.item(), losses1.size(0))
            self._train_loss2.update(loss2.item(), losses2.size(0))
            self._train_accuracy1.update(train_accuracy1[0], x.size(0))
            self._train_accuracy2.update(train_accuracy2[0], x.size(0))

            loss1.backward()
            loss2.backward()
            self._optimizer1.step()
            self._optimizer2.step()

            e = time.time()
            self._epoch_train_time.update(e - s, 1)
            if (log_iter and (it + 1) % log_freq == 0) or (it + 1 == len(
                    self._train_loader)):
                console_content = 'Epoch:[{:03d}/{:03d}]  Iter:[{:04d}/{:04d}]  ' \
                                  'Train Accuracy1 :[{:6.2f}]  Train Accuracy2 :[{:6.2f}]  ' \
                                  'Loss1:[{:4.4f}]  Loss2:[{:4.4f}]  ' \
                                  'Iter Runtime:[{:6.2f}]'.format(epoch + 1, self._epochs, it + 1,
                                                                  len(self._train_loader),
                                                                  self._train_accuracy1.avg, self._train_accuracy2.avg,
                                                                  self._train_loss1.avg, self._train_loss2.avg,
                                                                  self._epoch_train_time.avg)
                print_to_console(console_content)
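
cot_std_loss is project-specific and not shown. For reference, the classic co-teaching objective it builds on (Han et al., 2018) exchanges small-loss samples between the two networks; a hedged sketch of that baseline, not of the project's exact loss:

import torch.nn.functional as F

def coteaching_losses(logits1, logits2, y, forget_rate):
    """Each net is trained on the peer's small-loss (likely clean) samples."""
    loss1 = F.cross_entropy(logits1, y, reduction='none')
    loss2 = F.cross_entropy(logits2, y, reduction='none')
    n_keep = max(1, int((1.0 - forget_rate) * y.size(0)))
    idx1 = loss1.argsort()[:n_keep]                # samples net1 finds easy
    idx2 = loss2.argsort()[:n_keep]                # samples net2 finds easy
    return loss1[idx2].mean(), loss2[idx1].mean()  # cross update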
Example #7
class Trainer:
    """Pipeline to train a NN model using a certain dataset, both specified by an YML config."""
    @use_seed()
    def __init__(self, config_path, run_dir):
        self.config_path = coerce_to_path_and_check_exist(config_path)
        self.run_dir = coerce_to_path_and_create_dir(run_dir)
        self.logger = get_logger(self.run_dir, name="trainer")
        self.print_and_log_info(
            "Trainer initialisation: run directory is {}".format(run_dir))

        shutil.copy(self.config_path, self.run_dir)
        self.print_and_log_info("Config {} copied to run directory".format(
            self.config_path))

        with open(self.config_path) as fp:
            cfg = yaml.load(fp, Loader=yaml.FullLoader)

        if torch.cuda.is_available():
            type_device = "cuda"
            nb_device = torch.cuda.device_count()
            # XXX: set to False when input image sizes are not fixed
            torch.backends.cudnn.benchmark = cfg["training"].get(
                "cudnn_benchmark", True)

        else:
            type_device = "cpu"
            nb_device = None
        self.device = torch.device(type_device)
        self.print_and_log_info("Using {} device, nb_device is {}".format(
            type_device, nb_device))

        # Datasets and dataloaders
        self.dataset_kwargs = cfg["dataset"]
        self.dataset_name = self.dataset_kwargs.pop("name")
        train_dataset = get_dataset(self.dataset_name)("train",
                                                       **self.dataset_kwargs)
        val_dataset = get_dataset(self.dataset_name)("val",
                                                     **self.dataset_kwargs)
        self.restricted_labels = sorted(
            self.dataset_kwargs["restricted_labels"])
        self.n_classes = len(self.restricted_labels) + 1
        self.is_val_empty = len(val_dataset) == 0
        self.print_and_log_info("Dataset {} instantiated with {}".format(
            self.dataset_name, self.dataset_kwargs))
        self.print_and_log_info(
            "Found {} classes, {} train samples, {} val samples".format(
                self.n_classes, len(train_dataset), len(val_dataset)))

        self.batch_size = cfg["training"]["batch_size"]
        self.n_workers = cfg["training"]["n_workers"]
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=self.batch_size,
                                       num_workers=self.n_workers,
                                       shuffle=True)
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=self.batch_size,
                                     num_workers=self.n_workers)
        self.print_and_log_info(
            "Dataloaders instantiated with batch_size={} and n_workers={}".
            format(self.batch_size, self.n_workers))

        self.n_batches = len(self.train_loader)
        self.n_iterations, self.n_epoches = cfg["training"].get(
            "n_iterations"), cfg["training"].get("n_epoches")
        assert not (self.n_iterations is not None
                    and self.n_epoches is not None)
        if self.n_iterations is not None:
            self.n_epoches = max(self.n_iterations // self.n_batches, 1)
        else:
            self.n_iterations = self.n_epoches * len(self.train_loader)

        # Model
        self.model_kwargs = cfg["model"]
        self.model_name = self.model_kwargs.pop("name")
        model = get_model(self.model_name)(self.n_classes,
                                           **self.model_kwargs).to(self.device)
        self.model = torch.nn.DataParallel(model,
                                           device_ids=range(
                                               torch.cuda.device_count()))
        self.print_and_log_info("Using model {} with kwargs {}".format(
            self.model_name, self.model_kwargs))
        self.print_and_log_info('Number of trainable parameters: {}'.format(
            f'{count_parameters(self.model):,}'))

        # Optimizer
        optimizer_params = cfg["training"]["optimizer"] or {}
        optimizer_name = optimizer_params.pop("name", None)
        self.optimizer = get_optimizer(optimizer_name)(model.parameters(),
                                                       **optimizer_params)
        self.print_and_log_info("Using optimizer {} with kwargs {}".format(
            optimizer_name, optimizer_params))

        # Scheduler
        scheduler_params = cfg["training"].get("scheduler", {}) or {}
        scheduler_name = scheduler_params.pop("name", None)
        self.scheduler_update_range = scheduler_params.pop(
            "update_range", "epoch")
        assert self.scheduler_update_range in ["epoch", "batch"]
        if scheduler_name == "multi_step" and isinstance(
                scheduler_params["milestones"][0], float):
            n_tot = self.n_epoches if self.scheduler_update_range == "epoch" else self.n_iterations
            scheduler_params["milestones"] = [
                round(m * n_tot) for m in scheduler_params["milestones"]
            ]
        self.scheduler = get_scheduler(scheduler_name)(self.optimizer,
                                                       **scheduler_params)
        self.cur_lr = -1
        self.print_and_log_info("Using scheduler {} with parameters {}".format(
            scheduler_name, scheduler_params))

        # Loss
        loss_name = cfg["training"]["loss"]
        self.criterion = get_loss(loss_name)()
        self.print_and_log_info("Using loss {}".format(self.criterion))

        # Pretrained / Resume
        checkpoint_path = cfg["training"].get("pretrained")
        checkpoint_path_resume = cfg["training"].get("resume")
        assert not (checkpoint_path is not None
                    and checkpoint_path_resume is not None)
        if checkpoint_path is not None:
            self.load_from_tag(checkpoint_path)
        elif checkpoint_path_resume is not None:
            self.load_from_tag(checkpoint_path_resume, resume=True)
        else:
            self.start_epoch, self.start_batch = 1, 1

        # Train metrics
        train_iter_interval = cfg["training"].get(
            "train_stat_interval", self.n_epoches * self.n_batches // 200)
        self.train_stat_interval = train_iter_interval
        self.train_time = AverageMeter()
        self.train_loss = AverageMeter()
        self.train_metrics_path = self.run_dir / TRAIN_METRICS_FILE
        with open(self.train_metrics_path, mode="w") as f:
            f.write(
                "iteration\tepoch\tbatch\ttrain_loss\ttrain_time_per_img\n")

        # Val metrics
        val_iter_interval = cfg["training"].get(
            "val_stat_interval", self.n_epoches * self.n_batches // 100)
        self.val_stat_interval = val_iter_interval
        self.val_loss = AverageMeter()
        self.val_metrics = RunningMetrics(self.restricted_labels)
        self.val_current_score = None
        self.val_metrics_path = self.run_dir / VAL_METRICS_FILE
        with open(self.val_metrics_path, mode="w") as f:
            f.write("iteration\tepoch\tbatch\tval_loss\t" +
                    "\t".join(self.val_metrics.names) + "\n")

    def print_and_log_info(self, string):
        print_info(string)
        self.logger.info(string)

    def load_from_tag(self, tag, resume=False):
        self.print_and_log_info("Loading model from run {}".format(tag))
        path = coerce_to_path_and_check_exist(MODELS_PATH / tag / MODEL_FILE)
        checkpoint = torch.load(path, map_location=self.device)
        try:
            self.model.load_state_dict(checkpoint["model_state"])
        except RuntimeError:
            state = safe_model_state_dict(checkpoint["model_state"])
            self.model.module.load_state_dict(state)
        self.start_epoch, self.start_batch = 1, 1
        if resume:
            self.start_epoch, self.start_batch = checkpoint[
                "epoch"], checkpoint.get("batch", 0) + 1
            self.optimizer.load_state_dict(checkpoint["optimizer_state"])
            self.scheduler.load_state_dict(checkpoint["scheduler_state"])
            self.cur_lr = self.scheduler.get_lr()
        self.print_and_log_info(
            "Checkpoint loaded at epoch {}, batch {}".format(
                self.start_epoch, self.start_batch - 1))

    def _create_external_val_loader_and_monitor(self, dataset_name):
        val_dataset = get_dataset(dataset_name)(split="val",
                                                **self.dataset_kwargs)
        val_loader = DataLoader(val_dataset,
                                batch_size=self.batch_size,
                                num_workers=self.n_workers)
        self.print_and_log_info(
            "External {} validation dataset instantiated with kwargs {}: {} samples"
            .format(dataset_name, self.dataset_kwargs, len(val_dataset)))
        monitor = {}
        monitor["name"] = dataset_name
        monitor["loss"] = AverageMeter()
        monitor["metrics"] = RunningMetrics(val_dataset.restricted_labels,
                                            val_dataset.metric_labels)
        monitor["metrics_path"] = self.run_dir / "{}_metrics.tsv".format(
            dataset_name)
        with open(monitor["metrics_path"], mode="w") as f:
            f.write("iteration\tepoch\tbatch\t{}_loss\t".format(dataset_name) +
                    "\t".join(monitor["metrics"].names) + "\n")

        return val_loader, monitor

    @property
    def score_name(self):
        return self.val_metrics.score_name

    def print_memory_usage(self, prefix):
        usage = {}
        for attr in [
                "memory_allocated", "max_memory_allocated", "memory_cached",
                "max_memory_cached"
        ]:
            usage[attr] = getattr(torch.cuda, attr)() / (1024 ** 2)  # bytes -> MiB
        self.print_and_log_info("{}:\t{}".format(
            prefix, " / ".join(
                ["{}: {:.0f}MiB".format(k, v) for k, v in usage.items()])))

    @use_seed()
    def run(self):
        self.model.train()
        cur_iter = (self.start_epoch -
                    1) * self.n_batches + self.start_batch - 1
        prev_train_stat_iter, prev_val_stat_iter = cur_iter, cur_iter
        for epoch in range(self.start_epoch, self.n_epoches + 1):
            batch_start = self.start_batch if epoch == self.start_epoch else 1
            if self.scheduler_update_range == "epoch":
                if batch_start == 1:
                    self.update_scheduler(epoch, batch=batch_start)

            for batch, (images, labels) in enumerate(self.train_loader,
                                                     start=1):
                if batch < batch_start:
                    continue
                cur_iter += 1
                if cur_iter > self.n_iterations:
                    break

                if self.scheduler_update_range == "batch":
                    self.update_scheduler(epoch, batch=batch)

                self.single_train_batch_run(images, labels)
                if (cur_iter -
                        prev_train_stat_iter) >= self.train_stat_interval:
                    prev_train_stat_iter = cur_iter
                    self.log_train_metrics(cur_iter, epoch, batch)

                if (cur_iter - prev_val_stat_iter) >= self.val_stat_interval:
                    prev_val_stat_iter = cur_iter
                    self.run_val()
                    self.log_val_metrics(cur_iter, epoch, batch)
                    self.save(epoch=epoch, batch=batch)

        self.print_and_log_info("Training run is over")

    def update_scheduler(self, epoch, batch):
        self.scheduler.step()
        lr = self.scheduler.get_lr()
        if lr != self.cur_lr:
            self.cur_lr = lr
            msg = PRINT_LR_UPD_FMT.format(epoch, self.n_epoches, batch,
                                          self.n_batches, lr)
            self.print_and_log_info(msg)

    def single_train_batch_run(self, images, labels):
        start_time = time.time()
        images, labels = images.to(self.device), labels.to(self.device)

        self.optimizer.zero_grad()
        loss = self.criterion(self.model(images), labels)
        loss.backward()
        self.optimizer.step()

        self.train_loss.update(loss.item())
        self.train_time.update((time.time() - start_time) / self.batch_size)

    def log_train_metrics(self, cur_iter, epoch, batch):
        stat = PRINT_TRAIN_STAT_FMT.format(epoch, self.n_epoches, batch,
                                           self.n_batches, self.train_loss.avg,
                                           self.train_time.avg)
        self.print_and_log_info(stat)

        with open(self.train_metrics_path, mode="a") as f:
            f.write("{}\t{}\t{}\t{:.4f}\t{:.4f}\n".format(
                cur_iter, epoch, batch, self.train_loss.avg,
                self.train_time.avg))

        self.train_loss.reset()
        self.train_time.reset()

    def run_val(self):
        self.model.eval()
        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                pred = outputs.data.max(1)[1].cpu().numpy()
                if images.size() == labels.size():
                    gt = labels.data.max(1)[1].cpu().numpy()
                else:
                    gt = labels.cpu().numpy()

                self.val_metrics.update(gt, pred)
                self.val_loss.update(loss.item())

        self.model.train()

    def log_val_metrics(self, cur_iter, epoch, batch):
        stat = PRINT_VAL_STAT_FMT.format(epoch, self.n_epoches, batch,
                                         self.n_batches, self.val_loss.avg)
        self.print_and_log_info(stat)

        metrics = self.val_metrics.get()
        self.print_and_log_info(
            "Val metrics: " +
            ", ".join(["{} = {:.4f}".format(k, v)
                       for k, v in metrics.items()]))

        with open(self.val_metrics_path, mode="a") as f:
            f.write("{}\t{}\t{}\t{:.4f}\t".format(cur_iter, epoch, batch,
                                                  self.val_loss.avg) +
                    "\t".join(map("{:.4f}".format, metrics.values())) + "\n")

        self.val_current_score = metrics[self.score_name]
        self.val_loss.reset()
        self.val_metrics.reset()

    def save(self, epoch, batch):
        state = {
            "epoch": epoch,
            "batch": batch,
            "model_name": self.model_name,
            "model_kwargs": self.model_kwargs,
            "model_state": self.model.state_dict(),
            "n_classes": self.n_classes,
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "score": self.val_current_score,
            "train_resolution": self.dataset_kwargs["img_size"],
            "restricted_labels": self.dataset_kwargs["restricted_labels"],
            "normalize": self.dataset_kwargs["normalize"],
        }

        save_path = self.run_dir / MODEL_FILE
        torch.save(state, save_path)
        self.print_and_log_info("Model saved at {}".format(save_path))
Example No. 8
class Trainer(object):
    def __init__(self, config):
        # Config
        self._config = config
        self._epochs = config['epochs']
        self._step = config['step']
        self._logfile = config['log']
        self._n_classes = config['n_classes']

        # Network
        Net, feature_dim = make_network(config['net'])

        if self._step == 0:
            net = Net(n_classes=self._n_classes, pretrained=True, use_two_step=False)
        elif self._step == 1:
            net = Net(n_classes=self._n_classes, pretrained=True, use_two_step=True)
        elif self._step == 2:
            net = Net(n_classes=self._n_classes, pretrained=False, use_two_step=True)
        else:
            raise ValueError('step can only be 0, 1 or 2')
        # Move network to cuda
        print('| Number of available GPUs : {} ({})'.format(torch.cuda.device_count(),
                                                            os.environ["CUDA_VISIBLE_DEVICES"]))
        if torch.cuda.device_count() >= 1:
            self._net = nn.DataParallel(net).cuda()
        else:
            raise NotImplementedError('CPU version is not implemented yet!')

        # Loss Criterion
        self.T_k = config['warmup_epochs']
        if self._step == 1:
            self.T_k = self._epochs

        # Optimizer
        if self._step == 1:
            params_to_optimize = self._net.module.fc.parameters()
        else:
            params_to_optimize = self._net.parameters()
        self._optimizer = make_optimizer(params_to_optimize, lr=config['lr'], weight_decay=config['weight_decay'],
                                         opt='SGD')

        self._scheduler = optim.lr_scheduler.CosineAnnealingLR(self._optimizer, T_max=self._epochs, eta_min=0)
        # metrics
        self._train_loss = AverageMeter()
        self._train_accuracy = AverageMeter()
        self._epoch_train_time = AverageMeter()

        # Dataloader
        train_transform = make_transform(phase='train', output_size=448)
        test_transform = make_transform(phase='test', output_size=448)
        train_data = IndexedImageFolder(os.path.join(config['data_base'], 'train'), transform=train_transform)
        test_data = IndexedImageFolder(os.path.join(config['data_base'], 'val'), transform=test_transform)
        self._train_loader = data.DataLoader(train_data, batch_size=config['batch_size'], shuffle=True, num_workers=4,
                                             pin_memory=True)
        self._test_loader = data.DataLoader(test_data, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
        print('|-----------------------------------------------------')
        print('| Number of samples in train set : {}'.format(len(train_data)))
        print('| Number of samples in test  set : {}'.format(len(test_data)))
        print('| Number of classes in train set : {}'.format(len(train_data.classes)))
        print('| Number of classes in test  set : {}'.format(len(test_data.classes)))
        print('|-----------------------------------------------------')
        assert len(train_data.classes) == self._n_classes and \
            len(test_data.classes) == self._n_classes, 'number of classes is wrong'

        # Resume or not
        if config['resume']:
            assert os.path.isfile('checkpoint.pth'), 'no checkpoint.pth exists!'
            print('---> loading checkpoint.pth <---')
            checkpoint = torch.load('checkpoint.pth')
            assert self._step == checkpoint['step'], 'step in checkpoint does not match step in argument'
            self._start_epoch = checkpoint['epoch']
            self._best_accuracy = checkpoint['best_accuracy']
            self._best_epoch = checkpoint['best_epoch']
            self._net.load_state_dict(checkpoint['state_dict'])
            self._optimizer.load_state_dict(checkpoint['optimizer'])
            self._scheduler.load_state_dict(checkpoint['scheduler'])
            self.memory_pool = checkpoint['memory_pool']
        else:
            print('---> no checkpoint loaded <---')
            if self._step == 2:
                print('---> loading step1_best_epoch.pth <---')
                assert os.path.isfile('model/step1_best_epoch.pth')
                self._net.load_state_dict(torch.load('model/step1_best_epoch.pth'))
            self._start_epoch = 0
            self._best_accuracy = 0.0
            self._best_epoch = None
            self.memory_pool = Queue(n_samples=len(train_data), memory_length=config['memory_length'])
        self._scheduler.last_epoch = self._start_epoch

    def train(self):
        console_header = 'Epoch\tTrain_Loss\tTrain_Accuracy\tTest_Accuracy\tEpoch_Runtime\tLearning_Rate'
        print_to_console(console_header)
        print_to_logfile(self._logfile, console_header, init=True)

        for t in range(self._start_epoch, self._epochs):
            epoch_start = time.time()
            # NOTE: passing an epoch to scheduler.step() is deprecated in
            # newer PyTorch releases.
            self._scheduler.step(epoch=t)
            # reset average meters
            self._train_loss.reset()
            self._train_accuracy.reset()

            self._net.train(True)
            self.single_epoch_training(t)
            test_accuracy = evaluate(self._test_loader, self._net)

            lr = get_lr_from_optimizer(self._optimizer)

            if test_accuracy > self._best_accuracy:
                self._best_accuracy = test_accuracy
                self._best_epoch = t + 1
                torch.save(self._net.state_dict(), 'model/step{}_best_epoch.pth'.format(self._step))
                # print('*', end='')
            epoch_end = time.time()
            single_epoch_runtime = epoch_end - epoch_start
            # Logging
            console_content = '{:05d}\t{:10.4f}\t{:14.4f}\t{:13.4f}\t{:13.2f}\t{:13.1e}'.format(
                t + 1, self._train_loss.avg, self._train_accuracy.avg, test_accuracy, single_epoch_runtime, lr)
            print_to_console(console_content)
            print_to_logfile(self._logfile, console_content, init=False)

            # save checkpoint
            save_checkpoint({
                'epoch': t + 1,
                'state_dict': self._net.state_dict(),
                'best_epoch': self._best_epoch,
                'best_accuracy': self._best_accuracy,
                'optimizer': self._optimizer.state_dict(),
                'step': self._step,
                'scheduler': self._scheduler.state_dict(),
                'memory_pool': self.memory_pool,
            })

        console_content = 'Best at epoch {}, test accuracy is {}'.format(self._best_epoch, self._best_accuracy)
        print_to_console(console_content)

        # rename log file, stats files and model
        os.rename(self._logfile, self._logfile.replace('.txt', '-{}_{}_{}_{:.4f}.txt'.format(
            self._config['net'], self._config['batch_size'], self._config['lr'], self._best_accuracy)))

    def single_epoch_training(self, epoch, log_iter=True, log_freq=100):
        if epoch >= self.T_k:
            stats_log_path = 'stats/drop_n_reuse_stats_epoch{:03d}.csv'.format(epoch+1)
            stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
            print_to_logfile(stats_log_path, stats_log_header, init=True, end='\n')
        for it, (x, y, indices) in enumerate(self._train_loader):
            s = time.time()

            x = x.cuda()
            y = y.cuda()
            self._optimizer.zero_grad()
            logits = self._net(x)
            losses, ce_loss = std_loss(logits, y, indices, self.T_k, epoch, self.memory_pool,
                                       eps=self._config['eps'])
            loss = losses.mean()

            self.memory_pool.update(indices=indices, losses=ce_loss.detach().data.cpu(),
                                    scores=F.softmax(logits, dim=1).detach().data.cpu(),
                                    labels=y.detach().data.cpu())

            train_accuracy = accuracy(logits, y, topk=(1,))

            self._train_loss.update(loss.item(), x.size(0))
            self._train_accuracy.update(train_accuracy[0], x.size(0))

            loss.backward()
            self._optimizer.step()

            e = time.time()
            self._epoch_train_time.update(e-s, 1)
            if (log_iter and (it+1) % log_freq == 0) or (it+1 == len(self._train_loader)):
                console_content = 'Epoch:[{0:03d}/{1:03d}]  Iter:[{2:04d}/{3:04d}]  ' \
                                  'Train Accuracy :[{4:6.2f}]  Loss:[{5:4.4f}]  ' \
                                  'Iter Runtime:[{6:6.2f}]'.format(epoch + 1, self._epochs, it + 1,
                                                                   len(self._train_loader),
                                                                   self._train_accuracy.avg,
                                                                   self._train_loss.avg, self._epoch_train_time.avg)
                print_to_console(console_content)
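
The trainer above relies on an `accuracy(logits, y, topk=(1,))` helper that returns one top-k percentage per requested k. A sketch in the style of the PyTorch ImageNet reference example, assuming (N, C) logits and (N,) integer targets; the repository's own helper may differ in detail:

import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy sketch: returns a list with one percentage per k."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()                                    # (maxk, N)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append((correct_k * 100.0 / target.size(0)).item())
        return res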
Example No. 9
def main(cfg, distributed=True):
    if distributed:
        # DDP: initialize the default process group
        dist.init_process_group('nccl')
        # DDP: bind this process to its GPU. NOTE: dist.get_rank() returns the
        # global rank; it equals the local rank only on a single-node run.
        local_rank = dist.get_rank()
        print(local_rank)
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
    else:
        device = torch.device("cuda:0")
        local_rank = 0

    ###################################################
    mode = cfg.mode
    n_class = cfg.n_class
    model_path = cfg.model_path  # save model
    log_path = cfg.log_path
    output_path = cfg.output_path

    if local_rank == 0:
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        if not os.path.exists(output_path):
            os.makedirs(output_path)

    task_name = cfg.task_name
    print(task_name)

    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = cfg.batch_size
    sub_batch_size = cfg.sub_batch_size
    size_g = (cfg.size_g, cfg.size_g)
    size_p = (cfg.size_p, cfg.size_p)
    num_workers = cfg.num_workers
    trainset_cfg = cfg.trainset_cfg
    valset_cfg = cfg.valset_cfg

    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')

    transformer_train = TransformerSegGL(crop_size=cfg.crop_size)
    dataset_train = OralDatasetSeg(
        trainset_cfg["img_dir"],
        trainset_cfg["mask_dir"],
        trainset_cfg["meta_file"],
        label=trainset_cfg["label"],
        transform=transformer_train,
    )
    if distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=num_workers,
                                      batch_size=batch_size,
                                      collate_fn=collateGL,
                                      sampler=sampler_train,
                                      pin_memory=True)
    else:
        dataloader_train = DataLoader(dataset_train,
                                      num_workers=num_workers,
                                      batch_size=batch_size,
                                      collate_fn=collateGL,
                                      shuffle=True,
                                      pin_memory=True)
    transformer_val = TransformerSegGLVal()
    dataset_val = OralDatasetSeg(valset_cfg["img_dir"],
                                 valset_cfg["mask_dir"],
                                 valset_cfg["meta_file"],
                                 label=valset_cfg["label"],
                                 transform=transformer_val)
    dataloader_val = DataLoader(dataset_val,
                                num_workers=2,
                                batch_size=batch_size,
                                collate_fn=collateGL,
                                shuffle=False,
                                pin_memory=True)

    ###################################
    print("creating models......")
    path_g = cfg.path_g
    path_g2l = cfg.path_g2l
    path_l2g = cfg.path_l2g
    model = GLNet(n_class, cfg.encoder, **cfg.model_cfg)
    if mode == 3:
        global_fixed = GLNet(n_class, cfg.encoder, **cfg.model_cfg)
    else:
        global_fixed = None
    model, global_fixed = create_model_load_weights(model,
                                                    global_fixed,
                                                    device,
                                                    mode=mode,
                                                    distributed=distributed,
                                                    local_rank=local_rank,
                                                    evaluation=False,
                                                    path_g=path_g,
                                                    path_g2l=path_g2l,
                                                    path_l2g=path_l2g)

    ###################################
    num_epochs = cfg.num_epochs
    learning_rate = cfg.lr

    optimizer = get_optimizer(model, mode, learning_rate=learning_rate)
    scheduler = LR_Scheduler(cfg.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    if cfg.loss == "ce":
        criterion = nn.CrossEntropyLoss(reduction='mean')
    elif cfg.loss == "sce":
        criterion = SymmetricCrossEntropyLoss(alpha=cfg.alpha,
                                              beta=cfg.beta,
                                              num_classes=cfg.n_class)
        # criterion4 = NormalizedSymmetricCrossEntropyLoss(alpha=cfg.alpha, beta=cfg.beta, num_classes=cfg.n_class)
    elif cfg.loss == "focal":
        criterion = FocalLoss(gamma=3)
    elif cfg.loss == "ce-dice":
        criterion = nn.CrossEntropyLoss(reduction='mean')
        # criterion2 =

    #######################################
    trainer = Trainer(criterion, optimizer, n_class, size_g, size_p,
                      sub_batch_size, mode, cfg.lamb_fmreg)
    evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode)
    evaluation = cfg.evaluation
    val_vis = cfg.val_vis

    best_pred = 0.0
    print("start training......")

    # log
    if local_rank == 0:
        f_log = open(os.path.join(log_path, ".log"), 'w')
        log = task_name + '\n'
        for k, v in cfg.__dict__.items():
            log += str(k) + ' = ' + str(v) + '\n'
        f_log.write(log)
        f_log.flush()
    # writer
    if local_rank == 0:
        writer = SummaryWriter(log_dir=log_path)
    writer_info = {}

    for epoch in range(num_epochs):
        trainer.set_train(model)
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0

        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)
            scheduler(optimizer, i_batch, epoch, best_pred)
            # loss = trainer.train(sample, model)
            loss = trainer.train(sample, model, global_fixed)
            train_loss += loss.item()
            score_train, score_train_global, score_train_local = trainer.get_scores(
            )

            batch_time.update(time.time() - start_time)
            start_time = time.time()

            if i_batch % 20 == 0 and local_rank == 0:
                if mode == 1:
                    tbar.set_description(
                        'Train loss: %.4f;global mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss /
                           (i_batch + 1), score_train_global["iou_mean"],
                           data_time.avg, batch_time.avg))
                elif mode == 2:
                    tbar.set_description(
                        'Train loss: %.4f;agg mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1), score_train["iou_mean"],
                           score_train_local["iou_mean"], data_time.avg,
                           batch_time.avg))
                else:
                    tbar.set_description(
                        'Train loss: %.4f;agg mIoU: %.4f; global mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss / (i_batch + 1), score_train["iou_mean"],
                           score_train_global["iouu_mean"],
                           score_train_local["iou_mean"], data_time.avg,
                           batch_time.avg))

        score_train, score_train_global, score_train_local = trainer.get_scores(
        )
        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()

        if evaluation and epoch % 1 == 0 and local_rank == 0:
            with torch.no_grad():
                model.eval()
                print("evaluating...")
                tbar = tqdm(dataloader_val)

                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    predictions, predictions_global, predictions_local = evaluator.eval_test(
                        sample, model, global_fixed)
                    score_val, score_val_global, score_val_local = evaluator.get_scores(
                    )

                    batch_time.update(time.time() - start_time)
                    if i_batch % 20 == 0 and local_rank == 0:
                        if mode == 1:
                            tbar.set_description(
                                'global mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val_global["iou_mean"], data_time.avg,
                                   batch_time.avg))
                        elif mode == 2:
                            tbar.set_description(
                                'agg mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val["iou_mean"],
                                   score_val_local["iou_mean"], data_time.avg,
                                   batch_time.avg))
                        else:
                            tbar.set_description(
                                'agg mIoU: %.4f; global mIoU: %.4f; local mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (score_val["iou_mean"],
                                   score_val_global["iou_mean"],
                                   score_val_local["iou_mean"], data_time.avg,
                                   batch_time.avg))

                    if val_vis and i_batch == len(
                            tbar) // 2:  # val set result visualize
                        mask_rgb = class_to_RGB(np.array(sample['mask'][1]))
                        mask_rgb = ToTensor()(mask_rgb)
                        writer_info.update(mask=mask_rgb,
                                           prediction_global=ToTensor()(
                                               class_to_RGB(
                                                   predictions_global[1])))
                        if mode == 2 or mode == 3:
                            writer_info.update(
                                prediction=ToTensor()(
                                    class_to_RGB(predictions[1])),
                                prediction_local=ToTensor()(
                                    class_to_RGB(predictions_local[1])))

                    start_time = time.time()

                data_time.reset()
                batch_time.reset()
                score_val, score_val_global, score_val_local = evaluator.get_scores(
                )
                evaluator.reset_metrics()

                # save model
                best_pred = save_ckpt_model(model, cfg, score_val,
                                            score_val_global, best_pred, epoch)
                # log
                update_log(
                    f_log, cfg,
                    [score_train, score_train_global, score_train_local],
                    [score_val, score_val_global, score_val_local], epoch)
                # writer
                if mode == 1:
                    writer_info.update(
                        loss=train_loss / len(tbar),
                        lr=optimizer.param_groups[0]['lr'],
                        mIOU={
                            "train": score_train_global["iou_mean"],
                            "val": score_val_global["iou_mean"],
                        },
                        global_mIOU={
                            "train": score_train_global["iou_mean"],
                            "val": score_val_global["iou_mean"],
                        },
                        mucosa_iou={
                            "train": score_train_global["iou"][2],
                            "val": score_val_global["iou"][2],
                        },
                        tumor_iou={
                            "train": score_train_global["iou"][3],
                            "val": score_val_global["iou"][3],
                        },
                    )
                else:
                    writer_info.update(
                        loss=train_loss / len(tbar),
                        lr=optimizer.param_groups[0]['lr'],
                        mIOU={
                            "train": score_train["iou_mean"],
                            "val": score_val["iou_mean"],
                        },
                        global_mIOU={
                            "train": score_train_global["iou_mean"],
                            "val": score_val_global["iou_mean"],
                        },
                        local_mIOU={
                            "train": score_train_local["iou_mean"],
                            "val": score_val_local["iou_mean"],
                        },
                        mucosa_iou={
                            "train": score_train["iou"][2],
                            "val": score_val["iou"][2],
                        },
                        tumor_iou={
                            "train": score_train["iou"][3],
                            "val": score_val["iou"][3],
                        },
                    )

                update_writer(writer, writer_info, epoch)
    if local_rank == 0:
        f_log.close()
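
The `sce` branch above builds a SymmetricCrossEntropyLoss(alpha, beta, num_classes). Symmetric cross entropy (Wang et al., ICCV 2019) adds a reverse cross-entropy term in which the log of the one-hot labels is clamped so log(0) stays finite. A sketch under the assumption of (N, C) classification logits; for the segmentation maps used here, the class axis would first have to be flattened out:

import torch.nn as nn
import torch.nn.functional as F

class SymmetricCrossEntropyLoss(nn.Module):
    """alpha * CE + beta * reverse CE, after Wang et al., ICCV 2019."""

    def __init__(self, alpha, beta, num_classes):
        super().__init__()
        self.alpha, self.beta, self.num_classes = alpha, beta, num_classes

    def forward(self, logits, target):
        # forward term: -sum_k q(k) log p(k)
        ce = F.cross_entropy(logits, target)
        # reverse term: -sum_k p(k) log q(k), with log(0) clamped to log(1e-4)
        pred = F.softmax(logits, dim=1).clamp(min=1e-7, max=1.0)
        label = F.one_hot(target, self.num_classes).float().clamp(min=1e-4, max=1.0)
        rce = (-(pred * label.log()).sum(dim=1)).mean()
        return self.alpha * ce + self.beta * rce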
Example No. 10
def main():
    NUM_POINT = 20000
    opt = OptInit().initialize()
    opt.num_worker = 32
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpuNum

    opt.printer.info('===> Creating dataloader ...')

    train_dataset = BigredDataset(root=opt.train_path,
                                  is_train=True,
                                  is_validation=False,
                                  is_test=False,
                                  num_channel=opt.num_channel,
                                  pre_transform=T.NormalizeScale())
    # Shuffle training batches; keep the validation order deterministic.
    train_loader = DenseDataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_worker)

    validation_dataset = BigredDataset(root=opt.train_path,
                                       is_train=False,
                                       is_validation=True,
                                       is_test=False,
                                       num_channel=opt.num_channel,
                                       pre_transform=T.NormalizeScale())
    validation_loader = DenseDataLoader(validation_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_worker)

    opt.printer.info('===> computing Labelweight ...')

    # Two-class histogram over label values {0, 1} (bin edges 0, 1, 2).
    labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
    labelweights = labelweights.astype(np.float32)
    labelweights = labelweights / np.sum(labelweights)
    # Inverse-frequency class weights, smoothed by a cube root.
    labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
    weights = torch.Tensor(labelweights).cuda()
    print("labelweights", weights)

    opt.n_classes = train_loader.dataset.num_classes

    opt.printer.info('===> Loading the network ...')

    opt.best_value = 0
    print("GPU:",opt.device)
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        # Wrap the model for multi-GPU data parallelism.
        model = DataParallel(model).to(opt.device)
    opt.printer.info('===> loading pre-trained ...')
    # model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    opt.printer.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss(weight=weights).to(opt.device)
    # criterion_test = torch.nn.CrossEntropyLoss(weight = weights)

    if opt.optim.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optim.lower() == 'radam':
        optimizer = optim.RAdam(model.parameters(), lr=opt.lr)
    else:
        raise NotImplementedError('opt.optim is not supported')
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq, opt.lr_decay_rate)
    # optimizer, scheduler, opt.lr = load_pretrained_optimizer(opt.pretrained_model, optimizer, scheduler, opt.lr)

    opt.printer.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    # opt.test_metric = miou
    opt.test_values = AverageMeter()
    opt.test_value = 0.

    opt.printer.info('===> start training ...')
    writer = SummaryWriter()
    writer_test = SummaryWriter()
    counter_test = 0
    counter_play = 0
    start_epoch = 0
    mean_miou = AverageMeter()
    mean_loss = AverageMeter()
    mean_acc = AverageMeter()
    best_value = 0
    for epoch in range(start_epoch, opt.total_epochs):
        opt.epoch += 1
        model.train()
        total_seen_class = [0 for _ in range(opt.n_classes)]
        total_correct_class = [0 for _ in range(opt.n_classes)]
        total_iou_deno_class = [0 for _ in range(opt.n_classes)]
        ave_mIoU = 0
        total_correct = 0
        total_seen = 0
        loss_sum = 0

        mean_miou.reset()
        mean_loss.reset()
        mean_acc.reset()


        for i, data in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):
            # if i % 50 == 0:
            opt.iter += 1
            if not opt.multi_gpus:
                data = data.to(opt.device)
            target = data.y
            batch_label2 = target.cpu().data.numpy()
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
            inputs = inputs[:, :opt.num_channel, :, :]
            gt = data.y.to(opt.device)
            # ------------------ zero, output, loss
            optimizer.zero_grad()
            out = model(inputs)

            loss = criterion(out, gt)
            #pdb.set_trace()

            # ------------------ optimization
            loss.backward()
            optimizer.step()

            seg_pred = out.transpose(2, 1)

            pred_val = seg_pred.contiguous().cpu().data.numpy()
            seg_pred = seg_pred.contiguous().view(-1, opt.n_classes)
            #pdb.set_trace()
            pred_val = np.argmax(pred_val, 2)
            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
            target = target.view(-1, 1)[:, 0]
            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
            correct = np.sum(pred_choice == batch_label)

            total_correct += correct
            total_seen += (opt.batch_size * NUM_POINT)
            loss_sum += loss.item()

            current_seen_class = [0 for _ in range(opt.n_classes)]
            current_correct_class = [0 for _ in range(opt.n_classes)]
            current_iou_deno_class = [0 for _ in range(opt.n_classes)]
            #pdb.set_trace()

            for l in range(opt.n_classes):
                #pdb.set_trace()
                total_seen_class[l] += np.sum((batch_label2 == l))
                total_correct_class[l] += np.sum((pred_val == l) & (batch_label2 == l))
                total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label2 == l)))
                current_seen_class[l] = np.sum((batch_label2 == l))
                current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))

            #pdb.set_trace()
            writer.add_scalar('training_loss', loss.item(), counter_play)
            writer.add_scalar('training_accuracy', correct / float(opt.batch_size * NUM_POINT), counter_play)
            # np.float was removed in NumPy 1.24; np.float64 preserves the intent.
            m_iou = np.mean(np.array(current_correct_class) / (np.array(current_iou_deno_class, dtype=np.float64) + 1e-6))
            writer.add_scalar('training_mIoU', m_iou, counter_play)
            ave_mIoU = np.mean(np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6))

            # print("training_loss:",loss.item())
            # print('training_accuracy:',correct / float(opt.batch_size * NUM_POINT))
            # print('training_mIoU:',m_iou)

            mean_miou.update(m_iou)
            mean_loss.update(loss.item())
            mean_acc.update(correct / float(opt.batch_size * NUM_POINT))

            counter_play = counter_play + 1

        train_mIoU = mean_miou.avg
        train_macc = mean_acc.avg
        train_mloss = mean_loss.avg

        print('Epoch: %d, Training point avg class IoU: %f' % (epoch, train_mIoU))
        print('Epoch: %d, Training mean loss: %f' % (epoch, train_mloss))
        print('Epoch: %d, Training accuracy: %f' % (epoch, train_macc))

        mean_miou.reset()
        mean_loss.reset()
        mean_acc.reset()

        print('validation_loader')

        model.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(validation_loader), total=len(validation_loader), smoothing=0.9):
                # if i % 50 ==0:
                if not opt.multi_gpus:
                    data = data.to(opt.device)

                target = data.y
                batch_label2 = target.cpu().data.numpy()

                inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
                inputs = inputs[:, :opt.num_channel, :, :]
                gt = data.y.to(opt.device)
                out = model(inputs)
                loss = criterion(out, gt)
                #pdb.set_trace()

                seg_pred = out.transpose(2, 1)
                pred_val = seg_pred.contiguous().cpu().data.numpy()
                seg_pred = seg_pred.contiguous().view(-1, opt.n_classes)
                pred_val = np.argmax(pred_val, 2)
                batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
                target = target.view(-1, 1)[:, 0]
                pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
                correct = np.sum(pred_choice == batch_label)
                current_seen_class = [0 for _ in range(opt.n_classes)]
                current_correct_class = [0 for _ in range(opt.n_classes)]
                current_iou_deno_class = [0 for _ in range(opt.n_classes)]
                for l in range(opt.n_classes):

                    current_seen_class[l] = np.sum((batch_label2 == l))
                    current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                    current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))
                m_iou = np.mean(
                    np.array(current_correct_class) / (np.array(current_iou_deno_class, dtype=np.float64) + 1e-6))
                mean_miou.update(m_iou)
                mean_loss.update(loss.item())
                mean_acc.update(correct / float(opt.batch_size * NUM_POINT))

        validation_mIoU = mean_miou.avg
        validation_macc = mean_acc.avg
        validation_mloss = mean_loss.avg
        writer.add_scalar('validation_loss', validation_mloss, epoch)
        print('Epoch: %d, validation mean loss: %f' % (epoch, validation_mloss))
        writer.add_scalar('validation_accuracy', validation_macc, epoch)
        print('Epoch: %d, validation accuracy: %f' % (epoch, validation_macc))
        writer.add_scalar('validation_mIoU', validation_mIoU, epoch)
        print('Epoch: %d, validation point avg class IoU: %f' % (epoch, validation_mIoU))

        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        package = {
            'epoch': opt.epoch,
            'state_dict': model_cpu,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_miou': train_mIoU,
            'train_accuracy': train_macc,
            'train_loss': train_mloss,
            'validation_mIoU': validation_mIoU,
            'validation_macc': validation_macc,
            'validation_mloss': validation_mloss,
            'num_channel': opt.num_channel,
            'gpuNum': opt.gpuNum,
            'time': time.ctime()
        }
        torch.save(package, 'saves/val_miou_%f_val_acc_%f_%d.pth' % (validation_mIoU, validation_macc, epoch))
        is_best = (best_value < validation_mIoU)
        print('Is Best? ', is_best)
        if is_best:
            best_value = validation_mIoU
            torch.save(package, 'saves/best_model.pth')
        print('Best IoU: %f' % best_value)
        scheduler.step()
    opt.printer.info('Saving the final model. Finished!')
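
For reference, the per-class IoU that the two loops above accumulate inline is, for each class l, the intersection count #(pred==l & gt==l) divided by the union count #(pred==l | gt==l). The same computation as a standalone helper; the function name and the eps default are ours, chosen to match the 1e-6 used inline:

import numpy as np

def per_class_iou(correct_class, iou_deno_class, eps=1e-6):
    """Per-class IoU and mean IoU from intersection/union counts."""
    inter = np.asarray(correct_class, dtype=np.float64)
    union = np.asarray(iou_deno_class, dtype=np.float64)
    iou = inter / (union + eps)  # eps guards classes absent from the batch
    return iou, iou.mean()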
Example No. 11
        train_loss += loss.item()
        scores_train = trainer.get_scores()

        batch_time.update(time.time() - start_time)
        start_time = time.time()

        if i_batch % 10 == 0 and local_rank == 0:
            tbar.set_description(
                'Train loss: %.4f; mIoU: %.4f; data time: %.2f; batch time: %.2f'
                % (train_loss / (i_batch + 1), scores_train["iou_mean"],
                   data_time.avg, batch_time.avg))

    if local_rank == 0:
        writer.add_scalar('loss', train_loss / len(tbar), epoch)
    trainer.reset_metrics()
    data_time.reset()
    batch_time.reset()

    if epoch % 1 == 0 and local_rank == 0:
        with torch.no_grad():
            model.eval()
            print("evaluating...")

            if test:
                tbar = tqdm(dataloader_test)
            else:
                tbar = tqdm(dataloader_val)

            start_time = time.time()
            for i_batch, sample in enumerate(tbar):
                data_time.update(time.time() - start_time)
Example No. 12
    def train_one_epoch(self, epoch):
        train_errors = AverageMeter()
        train_losses = AverageMeter()
        train_iter = tqdm.tqdm(self.train_loader,
                               desc='Train Epoch',
                               total=self.n_batch_train,
                               leave=False)
        self.model.train()
        for i, batch in enumerate(train_iter):
            image = batch['image']
            gaze = batch['gaze']
            if self.pose_mode:
                pose = batch['pose']
                out = self.model(image, pose)
            else:
                out = self.model(image)
            num = image.size()[0]
            gaze_error_batch = np.mean(
                angular_error(out.cpu().data.numpy(),
                              gaze.cpu().data.numpy()))
            train_errors.update(gaze_error_batch.item(), num)

            loss_gaze = self.criterion(out, gaze)
            self.optimizer.zero_grad()
            # Accelerate's backward replaces a plain loss_gaze.backward() so
            # device placement and mixed precision stay framework-managed.
            accelerator.backward(loss_gaze)
            self.optimizer.step()
            train_losses.update(loss_gaze.item(), num)

            if i % self.config.log_freq == 0:
                if self.config.wandb:
                    wandb.log({
                        'epoch': epoch,
                        "batch": i,
                        "Train Errors": train_errors.avg,
                        "Train Losses": train_losses.avg
                    })

                postfix = {'Error': train_errors.avg, 'Loss': train_losses.avg}
                train_iter.set_postfix(postfix)
                train_errors.reset()
                train_losses.reset()

        if self.use_val:
            self.model.eval()
            val_errors = AverageMeter()
            val_losses = AverageMeter()
            val_iter = tqdm.tqdm(self.val_loader,
                                 desc='Val',
                                 total=self.n_batch_val,
                                 leave=False)
            for i, batch in enumerate(val_iter):
                image = batch['image']
                gaze = batch['gaze']
                if self.pose_mode:
                    pose = batch['pose']
                    out = self.model(image, pose)
                else:
                    out = self.model(image)
                num = image.size()[0]
                gaze_error_batch = np.mean(
                    angular_error(out.cpu().data.numpy(),
                                  gaze.cpu().data.numpy()))
                val_errors.update(gaze_error_batch.item(), num)
                loss_gaze = self.criterion(out, gaze)
                val_losses.update(loss_gaze.item(), num)

                if i % self.config.log_freq == 0:
                    postfix = {'Error': val_errors.avg, 'Loss': val_losses.avg}
                    val_iter.set_postfix(postfix)

            if self.config.wandb:
                wandb.log({
                    'epoch': epoch,
                    "Val Errors": val_errors.avg,
                    "Val Losses": val_losses.avg
                })

        return train_errors.avg, train_losses.avg
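
The `angular_error` used above compares predicted and ground-truth gaze directions. A common implementation, assuming gaze comes as (N, 2) pitch/yaw in radians and converting to 3D unit vectors before taking the arc cosine of their dot product; the sign convention of the conversion varies between gaze datasets:

import numpy as np

def pitchyaw_to_vector(angles):
    """(N, 2) pitch/yaw in radians -> (N, 3) unit gaze vectors."""
    pitch, yaw = angles[:, 0], angles[:, 1]
    return np.stack([
        -np.cos(pitch) * np.sin(yaw),
        -np.sin(pitch),
        -np.cos(pitch) * np.cos(yaw),
    ], axis=1)

def angular_error(pred, gt):
    """Per-sample angular error in degrees between (N, 2) pitch/yaw arrays."""
    a, b = pitchyaw_to_vector(pred), pitchyaw_to_vector(gt)
    cos = np.clip(np.sum(a * b, axis=1), -1.0, 1.0)  # clip for numeric safety
    return np.degrees(np.arccos(cos))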
Example No. 13
    def train(self,
              dataset_train,
              dataset_val,
              criterion,
              optimizer_func,
              trainer_func,
              evaluator_func,
              collate,
              dataset_test=None,
              tester_func=None):
        if self.distributed:
            sampler_train = DistributedSampler(dataset_train, shuffle=True)
            dataloader_train = DataLoader(dataset_train,
                                          num_workers=self.cfg.num_workers,
                                          batch_size=self.cfg.batch_size,
                                          collate_fn=collate,
                                          sampler=sampler_train,
                                          pin_memory=True)
        else:
            dataloader_train = DataLoader(dataset_train,
                                          num_workers=self.cfg.num_workers,
                                          batch_size=self.cfg.batch_size,
                                          collate_fn=collate,
                                          shuffle=True,
                                          pin_memory=True)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=self.cfg.num_workers,
                                    batch_size=self.cfg.batch_size,
                                    collate_fn=collate,
                                    shuffle=False,
                                    pin_memory=True)
        # if dataset_test:
        #     dataloader_test = DataLoader(dataset_test, num_workers=self.cfg.num_workers, batch_size=self.cfg.batch_size, collate_fn=collate, shuffle=False, pin_memory=True)
        ###################################
        print("creating models......")
        model = self.model_loader(self.model,
                                  self.device,
                                  distributed=self.distributed,
                                  local_rank=self.local_rank,
                                  evaluation=True,
                                  ckpt_path=self.cfg.ckpt_path)

        ###################################
        num_epochs = self.cfg.num_epochs
        learning_rate = self.cfg.lr
        data_time = AverageMeter("DataTime", ':3.3f')
        batch_time = AverageMeter("BatchTime", ':3.3f')

        optimizer = optimizer_func(model, learning_rate=learning_rate)
        scheduler = LR_Scheduler(self.cfg.scheduler, learning_rate, num_epochs,
                                 len(dataloader_train))
        ##################################
        trainer = trainer_func(criterion, optimizer, self.cfg.n_class)
        evaluator = evaluator_func(self.cfg.n_class)
        if tester_func:
            tester = tester_func(self.cfg.n_class, self.cfg.num_workers,
                                 self.cfg.batch_size)

        evaluation = self.cfg.evaluation
        val_vis = self.cfg.val_vis
        best_pred = 0.0
        print("start training......")

        # log
        if self.local_rank == 0:
            f_log = open(self.cfg.log_path + self.cfg.task_name + ".log", 'w')
            log = self.cfg.task_name + '\n'
            for k, v in self.cfg.__dict__.items():
                log += str(k) + ' = ' + str(v) + '\n'
            print(log)
            f_log.write(log)
            f_log.flush()
        # writer
        if self.local_rank == 0:
            writer = SummaryWriter(log_dir=self.cfg.writer_path)
        writer_info = {}

        for epoch in range(num_epochs):
            optimizer.zero_grad()
            num_batch = len(dataloader_train)
            tbar = tqdm(dataloader_train)
            train_loss = 0

            start_time = time.time()
            model.train()
            for i_batch, sample in enumerate(tbar):
                data_time.update(time.time() - start_time)
                scheduler(optimizer, i_batch, epoch, best_pred)
                # loss = trainer.train(sample, model)
                if self.distributed:
                    loss = trainer.train(sample, model)
                else:
                    loss = trainer.train_acc(sample, model, i_batch, 2,
                                             num_batch)

                train_loss += loss.item()
                scores_train = trainer.get_scores()

                batch_time.update(time.time() - start_time)
                start_time = time.time()

                if i_batch % 20 == 0 and self.local_rank == 0:
                    tbar.set_description(
                        'Train loss: %.4f; mIoU: %.4f; data time: %.2f; batch time: %.2f'
                        % (train_loss /
                           (i_batch + 1), scores_train["iou_mean"],
                           data_time.avg, batch_time.avg))
                # break
            trainer.reset_metrics()
            data_time.reset()
            batch_time.reset()

            train_model_fr, train_seg_fr = trainer.calculate_avg_fr()

            if evaluation and epoch % 1 == 0 and self.local_rank == 0:
                with torch.no_grad():
                    model.eval()

                    ##--** evaluating **--
                    print("evaluating...")
                    tbar = tqdm(dataloader_val)
                    start_time = time.time()
                    for i_batch, sample in enumerate(tbar):
                        data_time.update(time.time() - start_time)
                        predictions = evaluator.eval(sample, model)
                        scores_val = evaluator.get_scores()

                        batch_time.update(time.time() - start_time)
                        if i_batch % 20 == 0 and self.local_rank == 0:
                            tbar.set_description(
                                'mIoU: %.4f; data time: %.2f; batch time: %.2f'
                                % (scores_val["iou_mean"], data_time.avg,
                                   batch_time.avg))

                        if val_vis and (
                                1 +
                                epoch) % 10 == 0:  # val set result visualize
                            for i in range(len(sample['id'])):
                                name = sample['id'][i] + '.png'
                                slide = name.split('_')[0]
                                slide_dir = os.path.join(
                                    self.cfg.val_output_path, slide)
                                if not os.path.exists(slide_dir):
                                    os.makedirs(slide_dir)
                                predictions_rgb = class_to_RGB(predictions[i])
                                predictions_rgb = cv2.cvtColor(
                                    predictions_rgb, cv2.COLOR_BGR2RGB)
                                cv2.imwrite(os.path.join(slide_dir, name),
                                            predictions_rgb)
                                # writer_info.update(mask=mask_rgb, prediction=predictions_rgb)
                        start_time = time.time()
                        # break
                    data_time.reset()
                    batch_time.reset()
                    scores_val = evaluator.get_scores()
                    evaluator.reset_metrics()

                    val_model_fr, val_seg_fr = evaluator.calculate_avg_fr()

                    ##--** testing **--
                    # NOTE: the logging below assumes dataset_test is provided;
                    # otherwise scores_test and test_*_fr are undefined.
                    if dataset_test:
                        print("testing...")
                        num_slides = len(dataset_test.slides)
                        tbar2 = tqdm(range(num_slides))
                        start_time = time.time()
                        for i in tbar2:
                            dataset_test.get_patches_from_index(i)
                            data_time.update(time.time() - start_time)
                            predictions, output, _ = tester.inference(
                                dataset_test, model)
                            mask = dataset_test.get_slide_mask_from_index(i)
                            tester.update_scores(mask, predictions)
                            scores_test = tester.get_scores()
                            batch_time.update(time.time() - start_time)
                            tbar2.set_description(
                                'mIoU: %.4f; data time: %.2f; slide time: %.2f'
                                % (scores_test["iou_mean"], data_time.avg,
                                   batch_time.avg))

                            output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
                            cv2.imwrite(
                                os.path.join(self.cfg.test_output_path,
                                             dataset_test.slide + '.png'),
                                output)
                            # writer_info.update(mask=mask_rgb, prediction=predictions_rgb)
                            start_time = time.time()
                            # break
                        data_time.reset()
                        batch_time.reset()
                        scores_test = tester.get_scores()
                        tester.reset_metrics()

                        test_model_fr, test_seg_fr = tester.calculate_avg_fr()

                    # save model
                    best_pred = save_ckpt_model(model, self.cfg, scores_val,
                                                best_pred, epoch)
                    # log
                    update_log(f_log,
                               self.cfg,
                               scores_train,
                               scores_val, [train_model_fr, train_seg_fr],
                               [val_model_fr, val_seg_fr],
                               epoch,
                               scores_test=scores_test,
                               test_fr=[test_model_fr, test_seg_fr])
                    # writer
                    if self.cfg.n_class == 4:
                        writer_info.update(loss=train_loss / len(tbar),
                                           lr=optimizer.param_groups[0]['lr'],
                                           mIOU={
                                               "train":
                                               scores_train["iou_mean"],
                                               "val": scores_val["iou_mean"],
                                               "test": scores_test["iou_mean"],
                                           },
                                           mucosa_iou={
                                               "train": scores_train["iou"][2],
                                               "val": scores_val["iou"][2],
                                               "test": scores_test["iou"][2],
                                           },
                                           tumor_iou={
                                               "train": scores_train["iou"][3],
                                               "val": scores_val["iou"][3],
                                               "test": scores_test["iou"][3],
                                           },
                                           mucosa_model_fr={
                                               "train": train_model_fr[0],
                                               "val": val_model_fr[0],
                                               "test": test_model_fr[0],
                                           },
                                           tumor_model_fr={
                                               "train": train_model_fr[1],
                                               "val": val_model_fr[1],
                                               "test": val_model_fr[1],
                                           },
                                           mucosa_seg_fr={
                                               "train": train_seg_fr[0],
                                               "val": val_seg_fr[0],
                                               "test": test_seg_fr[0],
                                           },
                                           tumor_seg_fr={
                                               "train": train_seg_fr[1],
                                               "val": val_seg_fr[1],
                                               "test": test_seg_fr[1],
                                           })
                    else:
                        writer_info.update(loss=train_loss / len(tbar),
                                           lr=optimizer.param_groups[0]['lr'],
                                           mIOU={
                                               "train":
                                               scores_train["iou_mean"],
                                               "val": scores_val["iou_mean"],
                                               "test": scores_test["iou_mean"],
                                           },
                                           merge_iou={
                                               "train": scores_train["iou"][2],
                                               "val": scores_val["iou"][2],
                                               "test": scores_test["iou"][2],
                                           },
                                           merge_model_fr={
                                               "train": train_model_fr[0],
                                               "val": val_model_fr[0],
                                               "test": test_model_fr[0],
                                           },
                                           merge_seg_fr={
                                               "train": train_seg_fr[0],
                                               "val": val_seg_fr[0],
                                               "test": val_seg_fr[0],
                                           })
                    update_writer(writer, writer_info, epoch)
        if self.local_rank == 0:
            f_log.close()
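
Several of these examples colorize label maps with `class_to_RGB` before writing them to disk or to TensorBoard. A sketch based on a lookup palette; the colors below and the class ordering (indices 2 and 3 for mucosa and tumor are suggested by the IoU keys logged above) are assumptions for illustration only:

import numpy as np

# Hypothetical 4-class palette: background, normal tissue, mucosa, tumor.
PALETTE = np.array([
    [0, 0, 0],
    [0, 255, 0],
    [0, 0, 255],
    [255, 0, 0],
], dtype=np.uint8)

def class_to_RGB(label_map):
    """Map an (H, W) integer label map to an (H, W, 3) uint8 color image."""
    return PALETTE[np.asarray(label_map, dtype=np.int64)]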