Example #1
    def _make_model(self):
        # prepare network
        self.logger.info("Creating graph and optimizer...")
        model = get_model('train', self.joint_num)
        model = DataParallel(model).cuda()
        optimizer = self.get_optimizer(model)
        if cfg.continue_train:
            start_epoch, model, optimizer = self.load_model(model, optimizer)
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.model = model
        self.optimizer = optimizer
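
For reference, a minimal sketch of the `load_model` helper the resume branch calls — hypothetical, but matching the `(start_epoch, model, optimizer)` return signature used above (assumes `glob`, `os` and `torch` are imported; the checkpoint keys are guesses):

    def load_model(self, model, optimizer):
        # hypothetical: resume from the newest snapshot in cfg.model_dir
        ckpts = sorted(glob.glob(os.path.join(cfg.model_dir, 'snapshot_*.pth')))
        checkpoint = torch.load(ckpts[-1])
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['network'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        return start_epoch, model, optimizer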
Example #2
# Imports assumed by this snippet; project-local helpers (Tools, DatasetUtil,
# DataUtil, StreamSegMetrics, BalancedDataParallel, stat) come from the
# surrounding repository and their module paths are not shown here.
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from PIL import Image
from tqdm import tqdm
from torch.nn import DataParallel
from torch.utils.data import DataLoader


class SSRunner(object):
    def __init__(self, config):
        self.config = config

        # Data
        self.dataset_ss_train, _, self.dataset_ss_val = DatasetUtil.get_dataset_by_type(
            DatasetUtil.dataset_type_ss,
            self.config.ss_size,
            is_balance=self.config.is_balance_data,
            data_root=self.config.data_root_path,
            train_label_path=self.config.label_path,
            max_size=self.config.max_size)
        self.data_loader_ss_train = DataLoader(self.dataset_ss_train,
                                               self.config.ss_batch_size,
                                               shuffle=True,
                                               num_workers=16,
                                               drop_last=True)
        self.data_loader_ss_val = DataLoader(self.dataset_ss_val,
                                             self.config.ss_batch_size,
                                             shuffle=False,
                                             num_workers=16,
                                             drop_last=True)

        # Model
        self.net = self.config.Net(num_classes=self.config.ss_num_classes,
                                   output_stride=self.config.output_stride,
                                   arch=self.config.arch)

        if self.config.only_train_ss:
            self.net = BalancedDataParallel(0, self.net, dim=0).cuda()
        else:
            self.net = DataParallel(self.net).cuda()
            pass
        cudnn.benchmark = True

        # Optimize
        self.optimizer = optim.SGD(params=[
            {
                'params': self.net.module.model.backbone.parameters(),
                'lr': self.config.ss_lr
            },
            {
                'params': self.net.module.model.classifier.parameters(),
                'lr': self.config.ss_lr * 10
            },
        ],
                                   lr=self.config.ss_lr,
                                   momentum=0.9,
                                   weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.config.ss_milestones, gamma=0.1)

        # Loss
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=255,
                                           reduction='mean').cuda()
        pass

    def train_ss(self, start_epoch=0, model_file_name=None):
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        # self.eval_ss(epoch=0)
        best_iou = 0.0

        for epoch in range(start_epoch, self.config.ss_epoch_num):
            Tools.print()
            Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr'],
                self.optimizer.param_groups[1]['lr']),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 1 Train the model
            all_loss = 0.0
            self.net.train()
            if self.config.is_balance_data:
                self.dataset_ss_train.reset()
                pass
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_train),
                                    total=len(self.data_loader_ss_train)):
                inputs, labels = inputs.float().cuda(), labels.long().cuda()
                self.optimizer.zero_grad()

                result = self.net(inputs)
                loss = self.ce_loss(result, labels)

                loss.backward()
                self.optimizer.step()

                all_loss += loss.item()

                # evaluate and checkpoint roughly 10 times per epoch
                if (i + 1) % max(1, len(self.data_loader_ss_train) // 10) == 0:
                    score = self.eval_ss(epoch=epoch)
                    mean_iou = score["Mean IoU"]
                    if mean_iou > best_iou:
                        best_iou = mean_iou
                        save_file_name = Tools.new_dir(
                            os.path.join(
                                self.config.ss_model_dir,
                                "ss_{}_{}_{}.pth".format(epoch, i, best_iou)))
                        torch.save(self.net.state_dict(), save_file_name)
                        Tools.print("Save Model to {}".format(save_file_name),
                                    txt_path=self.config.ss_save_result_txt)
                        Tools.print()
                    pass
                pass
            self.scheduler.step()
            ###########################################################################

            Tools.print("[E:{:3d}/{:3d}] ss loss:{:.4f}".format(
                epoch, self.config.ss_epoch_num,
                all_loss / len(self.data_loader_ss_train)),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 2 Save the model
            if epoch % self.config.ss_save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.ss_model_dir,
                                 "ss_{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.ss_save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3 Evaluate the model
            if epoch % self.config.ss_eval_epoch_freq == 0:
                score = self.eval_ss(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.ss_model_dir,
                         "ss_final_{}.pth".format(self.config.ss_epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.ss_save_result_txt)
        Tools.print()

        self.eval_ss(epoch=self.config.ss_epoch_num)
        pass

    def eval_ss(self, epoch=0, model_file_name=None):
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_val),
                                    total=len(self.data_loader_ss_val)):
                inputs = inputs.float().cuda()
                labels = labels.long().cuda()
                outputs = self.net(inputs)
                preds = outputs.detach().max(dim=1)[1].cpu().numpy()
                targets = labels.cpu().numpy()

                metrics.update(targets, preds)
                pass
            pass

        score = metrics.get_results()
        Tools.print("{} {}".format(epoch, metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score

    def inference_ss(self,
                     model_file_name=None,
                     data_loader=None,
                     save_path=None):
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        final_save_path = Tools.new_dir("{}_final".format(save_path))

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs, labels,
                    image_info_list) in tqdm(enumerate(data_loader),
                                             total=len(data_loader)):
                assert len(image_info_list) == 1

                # Labels: resize back to the original image size
                max_size = 1000
                size = Image.open(image_info_list[0]).size
                basename = os.path.basename(image_info_list[0])
                final_name = os.path.join(final_save_path,
                                          basename.replace(".JPEG", ".png"))
                if os.path.exists(final_name):
                    continue

                if size[0] < max_size and size[1] < max_size:
                    targets = F.interpolate(torch.unsqueeze(
                        labels[0].float().cuda(), dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest").detach().cpu()
                else:
                    targets = F.interpolate(torch.unsqueeze(labels[0].float(),
                                                            dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest")
                targets = targets[0].long().numpy()

                # Prediction: average the outputs over the multi-scale inputs
                outputs = 0
                for input_index, input_one in enumerate(inputs):
                    output_one = self.net(input_one.float().cuda())
                    if size[0] < max_size and size[1] < max_size:
                        outputs += F.interpolate(
                            output_one,
                            size=(size[1], size[0]),
                            mode="bilinear",
                            align_corners=False).detach().cpu()
                    else:
                        outputs += F.interpolate(output_one.detach().cpu(),
                                                 size=(size[1], size[0]),
                                                 mode="bilinear",
                                                 align_corners=False)
                        pass
                    pass
                outputs = outputs / len(inputs)
                preds = outputs.max(dim=1)[1].numpy()

                # Update the segmentation metrics
                metrics.update(targets, preds)

                if save_path:
                    Image.open(image_info_list[0]).save(
                        os.path.join(save_path, basename))
                    DataUtil.gray_to_color(
                        np.asarray(targets[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", "_l.png")))
                    DataUtil.gray_to_color(np.asarray(
                        preds[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", ".png")))
                    Image.fromarray(np.asarray(
                        preds[0], dtype=np.uint8)).save(final_name)
                    pass
                pass
            pass

        score = metrics.get_results()
        Tools.print("{}".format(metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score

    def load_model(self, model_file_name):
        Tools.print("Load model form {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        checkpoint = torch.load(model_file_name)

        if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1:
            # On a single GPU, a checkpoint saved from DataParallel may need
            # its "module." prefix stripped before loading:
            # checkpoint = {key.replace("module.", ""): checkpoint[key] for key in checkpoint}
            pass

        self.net.load_state_dict(checkpoint, strict=True)
        Tools.print("Restore from {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        pass

    def stat(self):
        stat(self.net, (3, self.config.ss_size, self.config.ss_size))
        pass

    pass
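
A minimal usage sketch for the runner above, assuming a config object that exposes the fields referenced in `__init__` (`SSConfig` here is a hypothetical placeholder):

if __name__ == '__main__':
    config = SSConfig()  # hypothetical: provides ss_size, ss_lr, ss_epoch_num, ...
    runner = SSRunner(config=config)
    runner.train_ss(start_epoch=0)
    runner.eval_ss(epoch=config.ss_epoch_num)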
Example #3
# Fragment from inside a training loop; names such as eo, mile_stone,
# track_step, i, epoch_outer and batch_size come from the enclosing scope.
    # choose a random warm-up length within the current curriculum window
    pre_run = np.random.randint((eo + 1 - mile_stone) * track_step + 1)
    second = int(pre_run)
    third = second + track_step
print_flag = i < epoch_outer or i % 10 == 1
trk_state = model.module.init_state(batch_size=batch_size)
data = to_device(data, device=device)

# warm up the tracker state on the first `second` frames without gradients
model.eval()
for j in range(second):
    frame_data = get_frame_data(data, j)
    with torch.no_grad():
        out, trk_state = model(frame_data, state=trk_state)
        trk_state = model.module.map_gt_ids(out,
                                            frame_data,
                                            state=trk_state)

# train on the next track_step frames, accumulating per-frame losses
model.train()
opt.zero_grad()
losses = []
for j in range(second, third):
    frame_data = get_frame_data(data, j)
    output, trk_state = model(frame_data, state=trk_state)
    trk_state = model.module.map_gt_ids(output,
                                        frame_data,
                                        state=trk_state)
    loss_j = model.module.loss(output,
                               frame_data,
                               verbose=print_flag)
    losses.append(loss_j)
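
The fragment ends before the backward pass; a typical continuation, assuming each `loss_j` is a scalar tensor (a sketch, not the original code):

total_loss = sum(losses) / len(losses)  # average the per-frame losses
total_loss.backward()
opt.step()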
Example #4
# Imports assumed by this snippet; get_SLOCountNet, Binarymetrics, validate
# and criterion are project-local and not shown here.
import os
import math
import traceback
import numpy as np
import torch
from tqdm import tqdm
from torch.nn import DataParallel


def train(args, pt_dir, chkpt_path, trainloader, devloader, writer, logger, hp,
          hp_str):

    model = get_SLOCountNet(hp).cuda()

    print("FOV: {}", model.get_fov(hp.features.n_fft))
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("N_parameters : {}".format(params))
    model = DataParallel(model)

    if hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    epoch = 0
    best_loss = np.inf

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['step']

        # training continues with the newly supplied hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams differ from the checkpoint's.")
    else:
        logger.info("Starting new training run")

    try:

        for epoch in range(epoch, hp.train.n_epochs):

            vad_scores = Binarymetrics.BinaryMeter()  # activity scores
            vod_scores = Binarymetrics.BinaryMeter()  # overlap scores
            count_scores = Binarymetrics.MultiMeter()  # Countnet scores

            model.train()
            tot_loss = 0

            with tqdm(trainloader) as t:
                t.set_description("Epoch: {}".format(epoch))

                # iterate the tqdm wrapper directly so the bar advances itself
                for count, batch in enumerate(t):

                    features, labels = batch
                    features = features.cuda()
                    labels = labels.cuda()

                    preds = model(features)

                    # `criterion` is assumed to be defined at module scope
                    # (it is not part of this snippet).
                    loss = criterion(preds, labels)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # compute proper metrics for VAD
                    loss = loss.item()

                    if loss > 1e8 or math.isnan(loss):  # check if exploded
                        logger.error("Loss exploded to %.02f at step %d!" %
                                     (loss, epoch))
                        raise Exception("Loss exploded")

                    VADpreds = torch.sum(torch.exp(preds[:, 1:5, :]),
                                         dim=1).unsqueeze(1)
                    VADlabels = torch.sum(labels[:, 1:5, :],
                                          dim=1).unsqueeze(1)
                    vad_scores.update(VADpreds, VADlabels)

                    VODpreds = torch.sum(torch.exp(preds[:, 2:5, :]),
                                         dim=1).unsqueeze(1)
                    VODlabels = torch.sum(labels[:, 2:5, :],
                                          dim=1).unsqueeze(1)
                    vod_scores.update(VODpreds, VODlabels)

                    count_scores.update(
                        torch.argmax(torch.exp(preds), 1).unsqueeze(1),
                        torch.argmax(labels, 1).unsqueeze(1))

                    tot_loss += loss

                    vad_fa = vad_scores.get_fa().item()
                    vad_miss = vad_scores.get_miss().item()
                    vad_precision = vad_scores.get_precision().item()
                    vad_recall = vad_scores.get_recall().item()
                    vad_matt = vad_scores.get_matt().item()
                    vad_f1 = vad_scores.get_f1().item()
                    vad_tp = vad_scores.tp.item()
                    vad_tn = vad_scores.tn.item()
                    vad_fp = vad_scores.fp.item()
                    vad_fn = vad_scores.fn.item()

                    vod_fa = vod_scores.get_fa().item()
                    vod_miss = vod_scores.get_miss().item()
                    vod_precision = vod_scores.get_precision().item()
                    vod_recall = vod_scores.get_recall().item()
                    vod_matt = vod_scores.get_matt().item()
                    vod_f1 = vod_scores.get_f1().item()
                    vod_tp = vod_scores.tp.item()
                    vod_tn = vod_scores.tn.item()
                    vod_fp = vod_scores.fp.item()
                    vod_fn = vod_scores.fn.item()

                    count_fa = count_scores.get_accuracy().item()
                    count_miss = count_scores.get_miss().item()
                    count_precision = count_scores.get_precision().item()
                    count_recall = count_scores.get_recall().item()
                    count_matt = count_scores.get_matt().item()
                    count_f1 = count_scores.get_f1().item()
                    count_tp = count_scores.get_tp().item()
                    count_tn = count_scores.get_tn().item()
                    count_fp = count_scores.get_fp().item()
                    count_fn = count_scores.get_fn().item()

                    t.set_postfix(loss=tot_loss / (count + 1),
                                  vad_miss=vad_miss,
                                  vad_fa=vad_fa,
                                  vad_prec=vad_precision,
                                  vad_recall=vad_recall,
                                  vad_matt=vad_matt,
                                  vad_f1=vad_f1,
                                  vod_miss=vod_miss,
                                  vod_fa=vod_fa,
                                  vod_prec=vod_precision,
                                  vod_recall=vod_recall,
                                  vod_matt=vod_matt,
                                  vod_f1=vod_f1,
                                  count_miss=count_miss,
                                  count_fa=count_fa,
                                  count_prec=count_precision,
                                  count_recall=count_recall,
                                  count_matt=count_matt,
                                  count_f1=count_f1)

            writer.log_metrics("train_vad", loss, vad_fa, vad_miss, vad_recall,
                               vad_precision, vad_f1, vad_matt, vad_tp, vad_tn,
                               vad_fp, vad_fn, epoch)
            writer.log_metrics("train_vod", loss, vod_fa, vod_miss, vod_recall,
                               vod_precision, vod_f1, vod_matt, vod_tp, vod_tn,
                               vod_fp, vod_fn, epoch)
            writer.log_metrics("train_count", loss, count_fa, count_miss,
                               count_recall, count_precision, count_f1,
                               count_matt, count_tp, count_tn, count_fp,
                               count_fn, epoch)
            # end epoch save model and validate it

            val_loss = validate(hp, model, devloader, writer, epoch)

            if hp.train.save_best == 0:
                save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % epoch)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': epoch,
                        'hp_str': hp_str,
                    }, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)

            else:
                if val_loss < best_loss:  # save only when best
                    best_loss = val_loss
                    save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % epoch)
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'step': epoch,
                            'hp_str': hp_str,
                        }, save_path)
                    logger.info("Saved checkpoint to: %s" % save_path)

        return best_loss

    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
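
The `validate` helper called above is not part of the snippet; a minimal sketch consistent with its call signature and its use as a scalar dev loss (hypothetical; the real helper presumably also logs dev metrics through `writer`):

def validate(hp, model, devloader, writer, epoch):
    # hypothetical sketch: mean dev-set loss under torch.no_grad()
    model.eval()
    total = 0.0
    with torch.no_grad():
        for features, labels in devloader:
            features, labels = features.cuda(), labels.cuda()
            preds = model(features)
            total += criterion(preds, labels).item()
    model.train()
    return total / len(devloader)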
Example #5
        return data, label


device = "cuda"

data = torch.load("corpus.pt")
ds = DS(data['train']['src'], data['train']['label'], 1000)
train_data_loader = DataLoader(ds, batch_size=1000)

device_ids = [0, 7]

m = model(data['dict']['vocab_size'], data['dict']['label_size'], 1000)
m = m.to(device)

optimizer = torch.optim.Adam(m.parameters())
criterion = torch.nn.CrossEntropyLoss()
m = DataParallel(m, device_ids=device_ids)

if __name__ == "__main__":
    m.train()
    for _ in range(100):
        for data, label in train_data_loader:
            data = data.to(device)
            label = label.to(device)
            optimizer.zero_grad()
            target = m(data)

            loss = criterion(target, label)
            loss.backward()
            optimizer.step()
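
Note that `m` is wrapped in `DataParallel`, so its state dict keys carry a "module." prefix (the issue the commented-out remap in Example #2's `load_model` works around). Saving the underlying module sidesteps it; a short sketch with an assumed output path:

    # inside the __main__ block, after training: save without "module." keys
    torch.save(m.module.state_dict(), "model.pt")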