Exemplo n.º 1
0
    # Self-play training loop: each epoch first generates (X, y) samples by
    # playing games, then trains a fresh PolicyNetPlayer on those samples.
    timer = Timer()
    timer.reset()

    for e in range(num_epochs):
        # Use a 1-based epoch number in messages and worker arguments.
        # Rebinding the loop variable does not affect the range() iteration.
        e += 1
        args = (e, num_epochs, num_games_per_epoch)

        # One identical argument tuple per parallel worker.
        workers = [args] * parallelism

        print('Epoch', e, 'playing starting')
        if parallelism > 1:
            # Fan the playing step out over the pool and stack the per-worker
            # results row-wise.
            # presumably unzip() splits a list of (X, y) pairs into two
            # sequences — TODO confirm against its definition.
            result = pool.map(playing_step, workers)
            X, y = unzip(result)
            X = np.vstack(X)
            y = np.vstack(y)
        else:
            # Single-process path: call the playing step directly with load=False.
            X, y = playing_step(args, load=False)
        print('Epoch', e, 'playing finished')

        print('Epoch', e, 'training starting')
        p = PolicyNetPlayer()
        if parallelism > 1:
            # Clear the class-level cached network before starting the player.
            # NOTE(review): presumably this forces a reload of the weights the
            # workers used — confirm against PolicyNetPlayer.
            PolicyNetPlayer._cached_net = None
        p.start()
        p.train(X, y, verbose=1)
        p.end()
        print('Epoch', e, 'training finished')

        # NOTE(review): e is already 1-based here, so e + 1 is passed to
        # timer.eta(); confirm this matches what Timer.eta expects.
        print(timer.eta(e + 1, num_epochs))
Exemplo n.º 2
0
class Trainer(object):
    """Train and validate a ``CRG2Net`` model with region and density losses.

    All hyper-parameters are read from the module-level ``opt`` object.
    ``train(epoch)`` runs one pass over the training loader and
    ``validate(epoch)`` runs one pass over the validation loader and returns
    the harmonic mean of mIoU and region recall.
    """

    def __init__(self, mode):
        """Build loaders, model, losses, evaluator, optimizer and scheduler.

        Args:
            mode: Forwarded unchanged to ``Saver(opt, mode)``.
        """
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Dataset dataloader
        self.train_dataset, self.train_loader = make_data_loader(opt)
        self.nbatch_train = len(self.train_loader)
        self.val_dataset, self.val_loader = make_data_loader(opt, mode="val")
        self.nbatch_val = len(self.val_loader)

        # model
        # Sync BN is only enabled on multi-GPU runs. An unset (None) value
        # defaults to enabled there; an explicit True is now honoured instead
        # of being silently overwritten with False.
        multi_gpu = len(opt.gpu_id) > 1
        opt.sync_bn = multi_gpu and (opt.sync_bn is None or bool(opt.sync_bn))
        model = CRG2Net(opt)
        self.model = model.to(opt.device)

        # Loss
        if opt.use_balanced_weights:
            classes_weights_file = osp.join(opt.root_dir, 'train_classes_weights.npy')
            if os.path.isfile(classes_weights_file):
                weight = np.load(classes_weights_file)
            else:
                weight = calculate_weigths_labels(
                    self.train_loader, opt.root_dir)
            print(weight)
            opt.loss_region['weight'] = weight
        self.loss_region = build_loss(opt.loss_region)
        self.loss_density = build_loss(opt.loss_density)

        # Define Evaluator
        self.evaluator = Evaluator(dataset=opt.dataset)  # use region to eval: class_num is 2

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                # map_location lets a checkpoint saved on one device be
                # restored on whatever device this run is configured for.
                checkpoint = torch.load(opt.pre, map_location=opt.device)
                self.start_epoch = checkpoint['epoch']
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        # The state dict is loaded before wrapping, so checkpoint keys are
        # matched against the bare (non-DataParallel) model.
        if len(opt.gpu_id) > 1:
            self.logger.info("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # Define Optimizer and Lr Scheduler
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=opt.lr,
                                         momentum=opt.momentum,
                                         weight_decay=opt.decay)
        # Milestones are given in opt.steps as fractions of the total epochs.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)

        # Time
        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)

    def train(self, epoch):
        """Run one training epoch and step the LR scheduler once at the end.

        A batch that raises is reported with ``print`` and skipped
        (best-effort training).

        Args:
            epoch: Zero-based epoch index, used for the global step and logs.
        """
        self.model.train()
        if opt.freeze_bn:
            # Under DataParallel the real model lives in ``.module``.
            if len(opt.gpu_id) > 1:
                self.model.module.freeze_bn()
            else:
                self.model.freeze_bn()
        last_time = time.time()
        for iter_num, sample in enumerate(self.train_loader):
            try:
                # Clear gradients first: if the previous batch failed after
                # backward(), its gradients must not leak into this update.
                self.optimizer.zero_grad()

                imgs = sample["image"].to(opt.device)
                density_gt = sample["label"].to(opt.device)
                # Region target is the binary support of the density map.
                region_gt = (sample["label"] > 0).float().to(opt.device)

                region_pred, density_pred = self.model(imgs)

                region_loss = self.loss_region(region_pred, region_gt)
                density_loss = self.loss_density(density_pred, density_gt)
                loss = region_loss + density_loss
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
                # Read the scalar once; each .item() forces a device sync.
                loss_value = loss.item()
                self.loss_hist.append(loss_value)

                self.optimizer.step()

                # Visualize
                global_step = iter_num + self.nbatch_train * epoch + 1
                self.writer.add_scalar('train/loss', loss_value, global_step)
                batch_time = time.time() - last_time
                last_time = time.time()
                eta = self.timer.eta(global_step, batch_time)
                self.step_time.append(batch_time)
                if global_step % opt.print_freq == 0:
                    printline = ('Epoch: [{}][{}/{}] '
                                 'lr: {:1.5f}, '
                                 'eta: {}, time: {:1.1f}, '
                                 'region loss: {:1.4f}, '
                                 'density loss: {:1.4f}, '
                                 'loss: {:1.4f}').format(
                                    epoch, iter_num+1, self.nbatch_train,
                                    self.optimizer.param_groups[0]['lr'],
                                    eta, np.sum(self.step_time),
                                    region_loss, density_loss,
                                    np.mean(self.loss_hist))
                    self.logger.info(printline)

                # Drop graph references before the next batch is loaded.
                del loss, region_loss, density_loss

            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(e)
                continue

        self.scheduler.step()

    def validate(self, epoch):
        """Evaluate on the validation set and log the metrics.

        Args:
            epoch: Zero-based epoch index, used for the global step and as
                the x-axis of the validation scalars.

        Returns:
            Harmonic mean of mIoU and region recall, or ``0.0`` when either
            of them is not strictly positive.
        """
        self.model.eval()
        self.evaluator.reset()
        SMAE = 0
        with torch.no_grad():
            tbar = tqdm(self.val_loader, desc='\r')
            for i, sample in enumerate(tbar):
                imgs = sample['image'].to(opt.device)
                density_gt = sample["label"].to(opt.device)
                region_gt = (sample["label"] > 0).float()
                path = sample["path"]

                region_pred, density_pred = self.model(imgs)

                # Visualize
                global_step = i + self.nbatch_val * epoch + 1
                if global_step % opt.plot_every == 0:
                    pred = torch.argmax(region_pred, dim=1)
                    self.summary.visualize_image(self.writer,
                                                 opt.dataset,
                                                 imgs,
                                                 density_gt,
                                                 pred,
                                                 global_step)

                # metrics
                target = region_gt.numpy()
                pred = region_pred.data.cpu().numpy()
                pred = np.argmax(pred, axis=1).reshape(target.shape)
                self.evaluator.add_batch(target, pred, path)
                # Keep density only where the region head predicts foreground.
                density_pred = density_pred.clamp(min=0.00018) * region_pred.argmax(1, keepdim=True)
                # Accumulate the absolute count error per image, so errors of
                # opposite sign inside one batch cannot cancel. This equals
                # the whole-batch form when the validation batch size is 1.
                batch_size = density_gt.shape[0]
                gt_count = density_gt.reshape(batch_size, -1).sum(dim=1)
                pred_count = density_pred.reshape(batch_size, -1).sum(dim=1)
                SMAE += (gt_count - pred_count).abs().sum().item()

            # Fast test during the training
            MAE = SMAE / (len(self.val_dataset) * opt.norm_cfg['para'])
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            RRecall = self.evaluator.Region_Recall()
            RNum = self.evaluator.Region_Num()
            # Harmonic mean of mIoU and region recall; defined as 0 when
            # either term is not positive, to avoid dividing by zero.
            if mIoU > 0 and RRecall > 0:
                result = 2 / (1 / mIoU + 1 / RRecall)
            else:
                result = 0.0
            titles = ["mIoU", "MAE", "Acc", "Acc_class", "fwIoU", "RRecall", "RNum", "Result"]
            values = [mIoU, MAE, Acc, Acc_class, FWIoU, RRecall, RNum, result]
            for title, value in zip(titles, values):
                self.writer.add_scalar('val/'+title, value, epoch)

            printline = ("Val: mIoU: {:.4f}, MAE: {:.4f}, "
                         "Acc: {:.4f}, Acc_class: {:.4f}, fwIoU: {:.4f}, "
                         "RRecall: {:.4f}, RNum: {:.1f}, Result: {:.4f}]").format(
                            *values)
            self.logger.info(printline)

        return result
Exemplo n.º 3
0
class Trainer(object):
    """Train and validate a ``CRGNet`` model with a single configurable loss.

    All hyper-parameters are read from the module-level ``opt`` object.
    ``train(epoch)`` runs one pass over the training loader and
    ``validate(epoch)`` runs one pass over the validation loader and returns
    the harmonic mean of mIoU and region recall.
    """

    def __init__(self, mode):
        """Build loaders, model, loss, evaluator, optimizer and scheduler.

        Args:
            mode: Forwarded unchanged to ``Saver(opt, mode)``.
        """
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # Visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Dataset dataloader
        self.train_dataset, self.train_loader = make_data_loader(opt)
        self.nbatch_train = len(self.train_loader)
        self.val_dataset, self.val_loader = make_data_loader(opt, mode="val")
        self.nbatch_val = len(self.val_loader)

        # Model
        # Sync BN is only enabled on multi-GPU runs. An unset (None) value
        # defaults to enabled there; an explicit True is now honoured instead
        # of being silently overwritten with False.
        multi_gpu = len(opt.gpu_id) > 1
        opt.sync_bn = multi_gpu and (opt.sync_bn is None or bool(opt.sync_bn))
        model = CRGNet(opt)
        model_info(model, self.logger)
        self.model = model.to(opt.device)

        # Loss
        if opt.use_balanced_weights:
            classes_weights_file = osp.join(opt.root_dir, 'train_classes_weights.npy')
            if os.path.isfile(classes_weights_file):
                weight = np.load(classes_weights_file)
            else:
                weight = calculate_weigths_labels(
                    self.train_loader, opt.root_dir)
            print(weight)
            opt.loss['weight'] = weight
        self.loss = build_loss(opt.loss)

        # Define Evaluator
        self.evaluator = Evaluator()  # use region to eval: class_num is 2

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                # map_location lets a checkpoint saved on one device be
                # restored on whatever device this run is configured for.
                checkpoint = torch.load(opt.pre, map_location=opt.device)
                self.start_epoch = checkpoint['epoch']
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        # The state dict is loaded before wrapping, so checkpoint keys are
        # matched against the bare (non-DataParallel) model.
        if len(opt.gpu_id) > 1:
            print("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=opt.lr,
                                         momentum=opt.momentum,
                                         weight_decay=opt.decay)

        # Define lr scheduler
        # Milestones are given in opt.steps as fractions of the total epochs.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)

        # Time
        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)

    def train(self, epoch):
        """Run one training epoch and step the LR scheduler once at the end.

        A batch that raises is reported with ``print`` and skipped
        (best-effort training).

        Args:
            epoch: Zero-based epoch index, used for the global step and logs.
        """
        self.model.train()
        if opt.freeze_bn:
            # Under DataParallel the real model lives in ``.module``.
            if len(opt.gpu_id) > 1:
                self.model.module.freeze_bn()
            else:
                self.model.freeze_bn()
        last_time = time.time()
        for iter_num, sample in enumerate(self.train_loader):
            try:
                # Clear gradients first: if the previous batch failed after
                # backward(), its gradients must not leak into this update.
                self.optimizer.zero_grad()

                imgs = sample["image"].to(opt.device)
                labels = sample["label"].to(opt.device)

                output = self.model(imgs)

                loss = self.loss(output, labels)
                loss.backward()
                # Read the scalar once; each .item() forces a device sync.
                loss_value = loss.item()
                self.loss_hist.append(loss_value)

                self.optimizer.step()

                # Visualize
                global_step = iter_num + self.nbatch_train * epoch + 1
                self.writer.add_scalar('train/loss', loss_value, global_step)
                batch_time = time.time() - last_time
                last_time = time.time()
                eta = self.timer.eta(global_step, batch_time)
                self.step_time.append(batch_time)
                if global_step % opt.print_freq == 0:
                    printline = ('Epoch: [{}][{}/{}] '
                                 'lr: {:1.5f}, '
                                 'eta: {}, time: {:1.1f}, '
                                 'Loss: {:1.4f} ').format(
                                    epoch, iter_num+1, self.nbatch_train,
                                    self.optimizer.param_groups[0]['lr'],
                                    eta, np.sum(self.step_time),
                                    np.mean(self.loss_hist))
                    self.logger.info(printline)

                # Drop the graph reference before the next batch is loaded.
                del loss

            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(e)
                continue

        self.scheduler.step()

    def validate(self, epoch):
        """Evaluate on the validation set and log the metrics.

        Args:
            epoch: Zero-based epoch index, used for the global step and as
                the x-axis of the validation scalars.

        Returns:
            Harmonic mean of mIoU and region recall, or ``0.0`` when either
            of them is not strictly positive.
        """
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        with torch.no_grad():
            tbar = tqdm(self.val_loader, desc='\r')
            for i, sample in enumerate(tbar):
                imgs = sample['image'].to(opt.device)
                labels = sample['label'].to(opt.device)
                path = sample["path"]

                output = self.model(imgs)

                loss = self.loss(output, labels)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.4f' % (test_loss / (i + 1)))

                # Visualize
                global_step = i + self.nbatch_val * epoch + 1
                if global_step % opt.plot_every == 0:
                    # Multi-channel output is treated as class scores;
                    # single-channel output is clamped to be non-negative.
                    if output.shape[1] > 1:
                        pred = torch.argmax(output, dim=1)
                    else:
                        pred = torch.clamp(output, min=0)
                    self.summary.visualize_image(self.writer,
                                                 opt.dataset,
                                                 imgs,
                                                 labels,
                                                 pred,
                                                 global_step)

                # metrics
                pred = output.data.cpu().numpy()
                target = labels.cpu().numpy() > 0
                if pred.shape[1] > 1:
                    pred = np.argmax(pred, axis=1)
                # Binarise the prediction with the configured region threshold.
                pred = (pred > opt.region_thd).reshape(target.shape)
                self.evaluator.add_batch(target, pred, path, opt.dataset)

            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            RRecall = self.evaluator.Region_Recall()
            RNum = self.evaluator.Region_Num()
            mean_loss = test_loss / self.nbatch_val
            # Harmonic mean of mIoU and region recall; defined as 0 when
            # either term is not positive, to avoid dividing by zero.
            if mIoU > 0 and RRecall > 0:
                result = 2 / (1 / mIoU + 1 / RRecall)
            else:
                result = 0.0
            titles = ["mean_loss", "mIoU", "Acc", "Acc_class", "fwIoU", "RRecall", "RNum", "Result"]
            values = [mean_loss, mIoU, Acc, Acc_class, FWIoU, RRecall, RNum, result]
            for title, value in zip(titles, values):
                self.writer.add_scalar('val/'+title, value, epoch)

            # The log line reports every value except the final "Result".
            printline = ("Val, mean_loss: {:.4f}, mIoU: {:.4f}, "
                         "Acc: {:.4f}, Acc_class: {:.4f}, fwIoU: {:.4f}, "
                         "RRecall: {:.4f}, RNum: {:.1f}]").format(
                            *values[:-1])
            self.logger.info(printline)

        return result
Exemplo n.º 4
0
class Trainer(object):
    """Train and validate a detection ``Model`` with VOC- or COCO-style metrics.

    All hyper-parameters are read from the module-level ``opt`` object.
    ``training(epoch)`` runs one pass over the training loader and
    ``validate(epoch)`` runs one pass over the validation loader and returns
    the headline metric for the configured ``opt.eval_type``.
    """

    def __init__(self, mode):
        """Build loaders, model, optimizer, scheduler and evaluator.

        Args:
            mode: Forwarded unchanged to ``Saver(opt, mode)``.
        """
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir, opt)
        self.writer = self.summary.writer

        # Define Dataloader
        # train dataset
        self.train_dataset, self.train_loader = make_data_loader(opt, train=True)
        self.nbatch_train = len(self.train_loader)
        self.num_classes = self.train_dataset.num_classes

        # val dataset
        self.val_dataset, self.val_loader = make_data_loader(opt, train=False)
        self.nbatch_val = len(self.val_loader)

        # Define Network
        self.model = Model(opt, self.num_classes)
        self.model = self.model.to(opt.device)

        # Detection post process(NMS...)
        self.post_pro = PostProcess(**opt.nms)

        # Define Optimizer
        if opt.adam:
            self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        else:
            self.optimizer = optim.SGD(self.model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.decay)

        # Apex
        if opt.use_apex:
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                # map_location lets a checkpoint saved on one device be
                # restored on whatever device this run is configured for.
                checkpoint = torch.load(opt.pre, map_location=opt.device)
                # Resume from the epoch after the one stored in the checkpoint.
                self.start_epoch = checkpoint['epoch'] + 1
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        # Define lr scheduler
        # Milestones are given in opt.steps as fractions of the total epochs.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)
        # NOTE(review): this only moves the scheduler's epoch counter; it may
        # not re-apply the decay for milestones already passed before the
        # resume point — confirm against the installed torch version.
        self.scheduler.last_epoch = self.start_epoch - 1

        # Using mul gpu
        if len(opt.gpu_id) > 1:
            self.logger.info("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # metrics
        if opt.eval_type == 'cocoeval':
            self.eval = COCO_eval(self.val_dataset.coco)
        else:
            self.eval = VOC_eval(self.num_classes)

        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)

    def training(self, epoch):
        """Run one training epoch and step the LR scheduler once at the end.

        Batches whose total loss equals zero are skipped. A batch that raises
        is reported with ``print`` and skipped (best-effort training).

        Args:
            epoch: Zero-based epoch index, used for the global step and logs.
        """
        self.model.train()
        last_time = time.time()
        for iter_num, data in enumerate(self.train_loader):
            try:
                self.optimizer.zero_grad()
                inputs = data['img'].to(opt.device)
                targets = data['annot'].to(opt.device)

                losses = self.model(inputs, targets)
                loss, log_vars = parse_losses(losses)

                # Skip the update entirely when the loss is exactly zero.
                if bool(loss == 0):
                    continue
                if opt.use_apex:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), opt.grad_clip)
                self.optimizer.step()
                # Read the scalar once; each .item() forces a device sync.
                self.loss_hist.append(loss.item())

                # visualize
                global_step = iter_num + self.nbatch_train * epoch + 1
                log_parts = []
                for _key, _value in log_vars.items():
                    log_parts.append("{}: {:.4f}  ".format(_key, _value))
                    self.writer.add_scalar('train/{}'.format(_key),
                                           _value,
                                           global_step)
                loss_logs = "".join(log_parts)

                batch_time = time.time() - last_time
                last_time = time.time()
                eta = self.timer.eta(global_step, batch_time)
                self.step_time.append(batch_time)
                if global_step % opt.print_freq == 0:
                    printline = ("Epoch: [{}][{}/{}]  "
                                 "lr: {}  eta: {}  time: {:1.1f}  "
                                 "{}"
                                 "Running loss: {:1.5f}").format(
                                    epoch, iter_num + 1, self.nbatch_train,
                                    self.optimizer.param_groups[0]['lr'],
                                    eta, np.sum(self.step_time),
                                    loss_logs,
                                    np.mean(self.loss_hist))
                    self.logger.info(printline)

            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(e)
                continue

        self.scheduler.step()

    def validate(self, epoch):
        """Evaluate on the validation set and return the headline metric.

        Args:
            epoch: Zero-based epoch index, used for the global step and as
                the x-axis of the validation scalars.

        Returns:
            ``stats['AP']`` for ``opt.eval_type == "voceval"``, or ``stats[0]``
            for ``"cocoeval"``.

        Raises:
            NotImplementedError: If ``opt.eval_type`` is neither ``"voceval"``
                nor ``"cocoeval"``. Raised before any batch is processed.
        """
        self.model.eval()
        # Fail fast: without a supported eval type the pass below would run
        # to completion without collecting statistics and only then raise.
        if opt.eval_type not in ("voceval", "cocoeval"):
            raise NotImplementedError

        # start collecting results
        with torch.no_grad():
            for ii, data in enumerate(self.val_loader):
                scale = data['scale']
                index = data['index']
                inputs = data['img'].to(opt.device)
                targets = data['annot']

                # run network
                scores, labels, boxes = self.model(inputs)

                scores_bt, labels_bt, boxes_bt = self.post_pro(
                    scores, labels, boxes, inputs.shape[-2:])

                # One tensor per image: box columns, then label, then score.
                outputs = []
                for k in range(len(boxes_bt)):
                    outputs.append(torch.cat((
                        boxes_bt[k].clone(),
                        labels_bt[k].clone().unsqueeze(1).float(),
                        scores_bt[k].clone().unsqueeze(1)),
                        dim=1))

                # visualize
                global_step = ii + self.nbatch_val * epoch
                if global_step % opt.plot_every == 0:
                    self.summary.visualize_image(
                        inputs, targets, outputs,
                        self.val_dataset.labels,
                        global_step)

                # eval
                if opt.eval_type == "voceval":
                    self.eval.statistics(outputs, targets, iou_thresh=0.5)

                elif opt.eval_type == "cocoeval":
                    self.eval.statistics(outputs, scale, index)

                print('{}/{}'.format(ii, len(self.val_loader)), end='\r')

            if opt.eval_type == "voceval":
                stats, ap_class = self.eval.metric()
                for key, value in stats.items():
                    self.writer.add_scalar('val/{}'.format(key), value.mean(), epoch)
                self.saver.save_voc_eval_result(stats, ap_class, self.val_dataset.labels)
                return stats['AP']

            # Only "cocoeval" can reach this point (checked above).
            # NOTE(review): `metirc` looks like a typo for `metric`; confirm
            # the method name on COCO_eval before renaming.
            stats = self.eval.metirc()
            self.saver.save_coco_eval_result(stats)
            self.writer.add_scalar('val/mAP', stats[0], epoch)
            return stats[0]