Code Example #1
    def __init__(self, model, cfg, logger_name=None, device=0, config=None):
        super().__init__(device=device)
        self.cfg = cfg
        self.model = model
        if self.cuda_available:
            self.model.cuda()

        self.criterion = get_loss_function(config)
        optimizer_cls = get_optimizer(config)
        optimizer_params = {
            k: v
            for k, v in config["training"]["optimizer"].items() if k != "name"
        }
        self.optimizer = optimizer_cls(self.model.parameters(),
                                       **optimizer_params)

        self.logger = logging.getLogger(logger_name)
        self.inference_img_size = 16
        self.reducer = nn.AdaptiveAvgPool2d(self.inference_img_size)
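The optimizer here is built from a config dictionary: the "name" entry selects the optimizer class via get_optimizer(), and every other entry is forwarded as a keyword argument. A minimal sketch of the assumed config shape (hypothetical values, not taken from any of the projects below):

config = {
    "training": {
        "optimizer": {
            "name": "sgd",           # resolved by get_optimizer() to e.g. torch.optim.SGD
            "lr": 1.0e-2,            # everything except "name" becomes a kwarg
            "momentum": 0.9,
            "weight_decay": 5.0e-4,
        },
    },
}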
Code Example #2
def train(cfg, writer, logger):

    # Setup seeds
    # torch.manual_seed(cfg.get("seed", 1337))
    # torch.cuda.manual_seed(cfg.get("seed", 1337))
    # np.random.seed(cfg.get("seed", 1337))
    # random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
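            # Note: `args` below is assumed to be an argparse namespace defined at
            # module scope in the launching script; it is not a parameter of train().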

            if not args.load_weight_only:
                model = DataParallel_withLoss(model, loss_fn)
                model.load_state_dict(checkpoint["model_state"])
                if not args.not_load_optimizer:
                    optimizer.load_state_dict(checkpoint["optimizer_state"])

                # !!!
                # checkpoint["scheduler_state"]['last_epoch'] = -1
                # scheduler.load_state_dict(checkpoint["scheduler_state"])
                # start_iter = checkpoint["epoch"]
                start_iter = 0
                # import ipdb
                # ipdb.set_trace()
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
            else:
                pretrained_dict = convert_state_dict(checkpoint["model_state"])
                model_dict = model.state_dict()
                # 1. filter out unnecessary keys
                pretrained_dict = {
                    k: v
                    for k, v in pretrained_dict.items() if k in model_dict
                }
                # 2. overwrite entries in the existing state dict
                model_dict.update(pretrained_dict)
                # 3. load the new state dict
                model.load_state_dict(model_dict)
                model = DataParallel_withLoss(model, loss_fn)
                # import ipdb
                # ipdb.set_trace()
                # start_iter = -1
                logger.info(
                    "Loaded checkpoint '{}' (iter unknown, from pretrained icnet model)"
                    .format(cfg["training"]["resume"]))

        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, inst_labels) in trainloader:

            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            inst_labels = inst_labels.to(device)
            optimizer.zero_grad()

            loss, _, aux_info = model(labels,
                                      inst_labels,
                                      images,
                                      return_aux_info=True)
            loss = loss.sum()
            loss_sem = aux_info[0].sum()
            loss_inst = aux_info[1].sum()

            # loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f} (Sem:{:.4f}/Inst:{:.4f})  LR:{:.5f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    loss_sem.item(),
                    loss_inst.item(),
                    scheduler.get_lr()[0],
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                # print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()

                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                inst_labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        inst_labels_val = inst_labels_val.to(device)
                        # outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)
                        val_loss, (outputs, outputs_inst) = model(
                            labels_val, inst_labels_val, images_val)
                        val_loss = val_loss.sum()

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) % cfg["training"]["save_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_{:05d}_model.pkl".format(cfg["model"]["arch"],
                                                    cfg["data"]["dataset"],
                                                    i + 1),
                )
                torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
            i += 1
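Code Example #2 wraps the network in DataParallel_withLoss, a helper that is not shown above. A minimal sketch of that pattern, assuming the usual approach of computing the loss inside forward() so each GPU reduces its own loss and only small per-GPU loss tensors are gathered (the project's actual wrapper also threads instance labels through and returns auxiliary losses):

import torch
import torch.nn as nn

class _ModelWithLoss(nn.Module):
    # Hypothetical wrapper: run the model and its loss in one forward pass.
    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, labels, images):
        outputs = self.model(images)
        loss = self.loss_fn(input=outputs, target=labels)
        # Keep a leading dim so DataParallel can concatenate per-GPU losses;
        # the training loop then reduces them with loss.sum().
        return torch.unsqueeze(loss, 0), outputs

def DataParallel_withLoss(model, loss_fn, **kwargs):
    return nn.DataParallel(_ModelWithLoss(model, loss_fn), **kwargs)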
Code Example #3
File: train.py Project: clarenceyapp/UnMICST-info
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # # Setup Augmentations
    # augmentations = cfg["training"].get("augmentations", None)
    # data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
        drop_last=True,
    )

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
        drop_last=True,
    )

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model_orig = get_model(cfg["model"], n_classes).to(device)
    if cfg["training"]["pretrain"]:
        # Load a pretrained model
        model_orig.load_pretrained_model(
            model_path="pretrained/pspnet101_cityscapes.caffemodel")
        logger.info("Loaded pretrained model.")
    else:
        # No pretrained model
        logger.info("No pretraining.")

    model = torch.nn.DataParallel(model_orig,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    ### Visualize model training

    # helper function to show an image
    # (used in the `plot_classes_preds` function below)
    def matplotlib_imshow(data, is_image):
        if is_image:  #for images
            data = data / 4 + 0.5  # unnormalize
            npimg = data.numpy()
            plt.imshow(npimg, cmap="gray")
        else:  # for labels
            nplbl = data.numpy()
            plt.imshow(t_loader.decode_segmap(nplbl))

    def plot_classes_preds(data, batch_size, iter, is_image=True):
        fig = plt.figure(figsize=(12, 48))
        for idx in np.arange(batch_size):
            ax = fig.add_subplot(1, batch_size, idx + 1, xticks=[], yticks=[])
            matplotlib_imshow(data[idx], is_image)

            ax.set_title("Iteration Number " + str(iter))

        return fig

    best_iou = -100.0
    #best_val_loss = -100.0
    i = start_iter
    flag = True

    #Check if params trainable
    print('CHECK PARAMETER TRAINING:')
    for name, param in model.named_parameters():
        if param.requires_grad == False:
            print(name, param.data)

    while i <= cfg["training"]["train_iters"] and flag:
        for (images_orig, labels_orig, weights_orig,
             nuc_weights_orig) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()  #convert model into training mode
            images = images_orig.to(device)
            labels = labels_orig.to(device)
            weights = weights_orig.to(device)
            nuc_weights = nuc_weights_orig.to(device)

            optimizer.zero_grad()

            outputs = model(images)

            # Transform output to calculate meaningful loss
            out = outputs[0]

            # Resize output of network to same size as labels
            target_size = (labels.size()[1], labels.size()[2])
            out = torch.nn.functional.interpolate(out,
                                                  size=target_size,
                                                  mode='bicubic')

            # Multiply weights by loss output
            loss = loss_fn(input=out, target=labels)

            loss = torch.mul(loss, weights)  # add contour weights
            loss = torch.mul(loss, nuc_weights)  # add nuclei weights
            loss = loss.mean()  # average over all pixels to obtain a scalar loss

            loss.backward()  # computes gradients over network
            optimizer.step()  #updates parameters

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"][
                    "print_interval"] == 0:  # frequency with which visualize training update
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )
                #Show mini-batches during training

                # #Visualize only DAPI
                # writer.add_figure('Inputs',
                #     plot_classes_preds(images_orig.squeeze(), cfg["training"]["batch_size"], i, True),
                #             global_step=i)

                # writer.add_figure('Targets',
                #     plot_classes_preds(labels_orig, cfg["training"]["batch_size"], i, False),
                #             global_step=i)

                #Take max across classes (of probability maps) and assign class label to visualize semantic map
                #1)
                out_orig = torch.nn.functional.softmax(
                    outputs[0], dim=1).max(1).indices.cpu()
                #out_orig = out_orig.cpu().detach()
                #2)
                #out_orig = torch.argmax(outputs[0],dim=1)
                #3)
                #out_orig = outputs[0].data.max(1)[1].cpu()

                # #Visualize predictions
                # writer.add_figure('Predictions',
                #     plot_classes_preds(out_orig, cfg["training"]["batch_size"], i, False),
                #             global_step=i)

                #Save probability map
                prob_maps_folder = os.path.join(
                    writer.file_writer.get_logdir(), "probability_maps")
                os.makedirs(prob_maps_folder, exist_ok=True)

                #Downsample original images to target size for visualization
                images = torch.nn.functional.interpolate(images,
                                                         size=target_size,
                                                         mode='bicubic')

                out = torch.nn.functional.softmax(out, dim=1)

                contours = (out[:, 1, :, :]).unsqueeze(dim=1)
                nuclei = (out[:, 2, :, :]).unsqueeze(dim=1)
                background = (out[:, 0, :, :]).unsqueeze(dim=1)

                #imageTensor = torch.cat((images, contours, nuclei, background),dim=0)

                # Save images side by side: nrow is how many images per row
                #save_image(make_grid(imageTensor, nrow=2), os.path.join(prob_maps_folder,"Prob_maps_%d.tif" % i))

                # Targets visualization below
                nplbl = labels_orig.numpy()
                targets = []  #each element is RGB target label in batch
                for bs in np.arange(cfg["training"]["batch_size"]):
                    target_bs = t_loader.decode_segmap(nplbl[bs])
                    target_bs = 255 * target_bs
                    target_bs = target_bs.astype('uint8')
                    target_bs = torch.from_numpy(target_bs)
                    target_bs = target_bs.unsqueeze(dim=0)
                    targets.append(target_bs)  #uint8 labels, shape (N,N,3)

                target = reduce(lambda x, y: torch.cat((x, y), dim=0), targets)
                target = target.permute(0, 3, 1,
                                        2)  # size=(Batch, Channels, N, N)
                target = target.type(torch.FloatTensor)

                save_image(
                    make_grid(target, nrow=cfg["training"]["batch_size"]),
                    os.path.join(prob_maps_folder, "Target_labels_%d.tif" % i))

                # Weights visualization below:
                #wgts = weights_orig.type(torch.FloatTensor)
                #save_image(make_grid(wgts, nrow=2), os.path.join(prob_maps_folder,"Weights_%d.tif" % i))

                # Probability maps visualization below
                t1 = []
                t2 = []
                t3 = []
                t4 = []

                # Normalize individual images in batch
                for bs in np.arange(cfg["training"]["batch_size"]):
                    t1.append((images[bs][0] - images[bs][0].min()) /
                              (images[bs][0].max() - images[bs][0].min()))
                    t2.append(contours[bs])
                    t3.append(nuclei[bs])
                    t4.append(background[bs])

                t1 = [torch.unsqueeze(elem, dim=0)
                      for elem in t1]  #expand dim=0 for images in batch
                # Convert normalized batch to Tensor
                tensor1 = torch.cat((t1), dim=0)
                tensor2 = torch.cat((t2), dim=0)
                tensor3 = torch.cat((t3), dim=0)
                tensor4 = torch.cat((t4), dim=0)

                tTensor = torch.cat((tensor1, tensor2, tensor3, tensor4),
                                    dim=0)
                tTensor = tTensor.unsqueeze(dim=1)

                save_image(make_grid(tTensor,
                                     nrow=cfg["training"]["batch_size"]),
                           os.path.join(prob_maps_folder,
                                        "Prob_maps_%d.tif" % i),
                           normalize=False)

                logger.info(print_str)
                writer.add_scalar(
                    "loss/train_loss", loss.item(),
                    i + 1)  # adds value to history (title, loss, iter index)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1
            ) == cfg["training"][
                    "train_iters"]:  # evaluate model on validation set at these intervals
                model.eval()  # evaluate mode for model
                with torch.no_grad():
                    for i_val, (images_val, labels_val, weights_val,
                                nuc_weights_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        weights_val = weights_val.to(device)
                        nuc_weights_val = nuc_weights_val.to(device)

                        outputs_val = model(images_val)

                        # Resize output of network to same size as labels
                        target_val_size = (labels_val.size()[1],
                                           labels_val.size()[2])
                        outputs_val = torch.nn.functional.interpolate(
                            outputs_val, size=target_val_size, mode='bicubic')

                        # Multiply weights by loss output
                        val_loss = loss_fn(input=outputs_val,
                                           target=labels_val)

                        val_loss = torch.mul(val_loss, weights_val)
                        val_loss = torch.mul(val_loss, nuc_weights_val)
                        val_loss = val_loss.mean()  # average over all pixels to obtain a scalar loss

                        outputs_val = torch.nn.functional.softmax(outputs_val,
                                                                  dim=1)

                        #Save probability map
                        val_prob_maps_folder = os.path.join(
                            writer.file_writer.get_logdir(),
                            "val_probability_maps")
                        os.makedirs(val_prob_maps_folder, exist_ok=True)

                        #Downsample original images to target size for visualization
                        images_val = torch.nn.functional.interpolate(
                            images_val, size=target_val_size, mode='bicubic')

                        contours_val = (outputs_val[:,
                                                    1, :, :]).unsqueeze(dim=1)
                        nuclei_val = (outputs_val[:, 2, :, :]).unsqueeze(dim=1)
                        background_val = (outputs_val[:, 0, :, :]).unsqueeze(
                            dim=1)

                        # Targets visualization below
                        nplbl_val = labels_val.cpu().numpy()
                        targets_val = [
                        ]  #each element is RGB target label in batch
                        for bs in np.arange(cfg["training"]["batch_size"]):
                            target_bs = v_loader.decode_segmap(nplbl_val[bs])
                            target_bs = 255 * target_bs
                            target_bs = target_bs.astype('uint8')
                            target_bs = torch.from_numpy(target_bs)
                            target_bs = target_bs.unsqueeze(dim=0)
                            targets_val.append(
                                target_bs)  #uint8 labels, shape (N,N,3)

                        target_val = reduce(
                            lambda x, y: torch.cat((x, y), dim=0), targets_val)
                        target_val = target_val.permute(
                            0, 3, 1, 2)  # size=(Batch, Channels, N, N)
                        target_val = target_val.type(torch.FloatTensor)

                        save_image(
                            make_grid(target_val,
                                      nrow=cfg["training"]["batch_size"]),
                            os.path.join(
                                val_prob_maps_folder,
                                "Target_labels_%d_val_%d.tif" % (i, i_val)))

                        # Weights visualization below:
                        #wgts_val = weights_val.type(torch.FloatTensor)
                        #save_image(make_grid(wgts_val, nrow=2), os.path.join(val_prob_maps_folder,"Weights_val_%d.tif" % i_val))

                        # Probability maps visualization below
                        t1_val = []
                        t2_val = []
                        t3_val = []
                        t4_val = []
                        # Normalize individual images in batch
                        for bs in np.arange(cfg["training"]["batch_size"]):
                            t1_val.append(
                                (images_val[bs][0] - images_val[bs][0].min()) /
                                (images_val[bs][0].max() -
                                 images_val[bs][0].min()))
                            t2_val.append(contours_val[bs])
                            t3_val.append(nuclei_val[bs])
                            t4_val.append(background_val[bs])

                        t1_val = [
                            torch.unsqueeze(elem, dim=0) for elem in t1_val
                        ]  #expand dim=0 for images_val in batch
                        # Convert normalized batch to Tensor
                        tensor1_val = torch.cat((t1_val), dim=0)
                        tensor2_val = torch.cat((t2_val), dim=0)
                        tensor3_val = torch.cat((t3_val), dim=0)
                        tensor4_val = torch.cat((t4_val), dim=0)

                        tTensor_val = torch.cat((tensor1_val, tensor2_val,
                                                 tensor3_val, tensor4_val),
                                                dim=0)
                        tTensor_val = tTensor_val.unsqueeze(dim=1)

                        save_image(make_grid(
                            tTensor_val, nrow=cfg["training"]["batch_size"]),
                                   os.path.join(
                                       val_prob_maps_folder,
                                       "Prob_maps_%d_val_%d.tif" % (i, i_val)),
                                   normalize=False)

                        pred = outputs_val.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                ### Save best validation loss model
                # if val_loss_meter.avg >= best_val_loss:
                #     best_val_loss = val_loss_meter.avg
                #     state = {
                #         "epoch": i + 1,
                #         "model_state": model.state_dict(),
                #         "optimizer_state": optimizer.state_dict(),
                #         "scheduler_state": scheduler.state_dict(),
                #         "best_val_loss": best_val_loss,
                #     }
                #     save_path = os.path.join(
                #         writer.file_writer.get_logdir(),
                #         "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                #     )
                #     torch.save(state, save_path)
                ###

                score, class_iou = running_metrics_val.get_scores(
                )  # best model chosen via IoU
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                ### Save best mean IoU model
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                ###

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
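In Code Example #3 the loss is multiplied element-wise by contour and nuclei weight maps before averaging, which only works if loss_fn returns a per-pixel loss map rather than a scalar. A minimal sketch of that pattern, assuming an unreduced cross entropy (the project's actual loss function may differ):

import torch.nn.functional as F

def weighted_pixel_loss(logits, target, contour_weights, nuclei_weights):
    # logits: (B, C, H, W); target and both weight maps: (B, H, W)
    per_pixel = F.cross_entropy(logits, target, reduction="none")  # (B, H, W)
    per_pixel = per_pixel * contour_weights * nuclei_weights
    return per_pixel.mean()  # scalar for backward()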
Code Example #4
File: ohio_train.py Project: OpenGeoscience/deepres
def train(cfg, writer, logger):
    
    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
#    data_loader = get_loader(cfg['data']['dataset'])
#    data_path = cfg['data']['path']
#
#    t_loader = data_loader(
#        data_path,
#        is_transform=True,
#        split=cfg['data']['train_split'],
#        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
#        augmentations=data_aug)
#
#    v_loader = data_loader(
#        data_path,
#        is_transform=True,
#        split=cfg['data']['val_split'],
#        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)
#
#    n_classes = t_loader.n_classes
#    trainloader = data.DataLoader(t_loader,
#                                  batch_size=cfg['training']['batch_size'], 
#                                  num_workers=cfg['training']['n_workers'], 
#                                  shuffle=True)
#
#    valloader = data.DataLoader(v_loader, 
#                                batch_size=cfg['training']['batch_size'], 
#                                num_workers=cfg['training']['n_workers'])

    paths = {
        'masks': './satellitedata/patchohio_train/gt/',
        'images': './satellitedata/patchohio_train/rgb',
        'nirs': './satellitedata/patchohio_train/nir',
        'swirs': './satellitedata/patchohio_train/swir',
        'vhs': './satellitedata/patchohio_train/vh',
        'vvs': './satellitedata/patchohio_train/vv',
        'redes': './satellitedata/patchohio_train/rede',
        'ndvis': './satellitedata/patchohio_train/ndvi',
        }

    valpaths = {
        'masks': './satellitedata/patchohio_val/gt/',
        'images': './satellitedata/patchohio_val/rgb',
        'nirs': './satellitedata/patchohio_val/nir',
        'swirs': './satellitedata/patchohio_val/swir',
        'vhs': './satellitedata/patchohio_val/vh',
        'vvs': './satellitedata/patchohio_val/vv',
        'redes': './satellitedata/patchohio_val/rede',
        'ndvis': './satellitedata/patchohio_val/ndvi',
        }
  
  
    n_classes = 3
    train_img_paths = [pth for pth in os.listdir(paths['images']) if ('_01_' not in pth) and ('_25_' not in pth)]
    val_img_paths = [pth for pth in os.listdir(valpaths['images']) if ('_01_' not in pth) and ('_25_' not in pth)]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    train_idx = [i for i in range(ntrain)]
    val_idx = [i for i in range(nval)]
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')
    
    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = './satellitedata/'
        print('train_data_path: {}'.format(train_data_path))
        dataset_path, train_dir = os.path.split(train_data_path)
        print('dataset_path: {}'.format(dataset_path) + ',  train_dir: {}'.format(train_dir))
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = TrainDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    k = 0
    nbackground = 0
    ncorn = 0
    #ncotton = 0
    #nrice = 0
    nsoybean = 0


    for indata in trainloader:
        k += 1
        gt = indata['seg_label'].data.cpu().numpy()
        nbackground += (gt == 0).sum()
        ncorn += (gt == 1).sum()
        #ncotton += (gt == 2).sum()
        #nrice += (gt == 3).sum()
        nsoybean += (gt == 2).sum()

    print('k = {}'.format(k))
    print('nbackground: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    #print('ncotton: {}'.format(ncotton))
    #print('nrice: {}'.format(nrice))
    print('nsoybean: {}'.format(nsoybean))
    
    wgts = [1.0, 1.0*nbackground/ncorn, 1.0*nbackground/nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0]/total_wgts
    wgt_corn = wgts[1]/total_wgts
    #wgt_cotton = wgts[2]/total_wgts
    #wgt_rice = wgts[3]/total_wgts
    wgt_soybean = wgts[2]/total_wgts
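    # Note: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
    # a plain CUDA float tensor would suffice here.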
    weights = torch.autograd.Variable(torch.cuda.FloatTensor([wgt_background, wgt_corn, wgt_soybean]))

    #weights = torch.autograd.Variable(torch.cuda.FloatTensor([1.0, 1.0, 1.0]))
    

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() 
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for inputdata in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = inputdata['img_data']
            labels = inputdata['seg_label']
            #print('images.size: {}'.format(images.size()))
            #print('labels.size: {}'.format(labels.size()))
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            #print('outputs.size: {}'.format(outputs[1].size()))
            #print('labels.size: {}'.format(labels.size()))

            loss = loss_fn(input=outputs[1], target=labels, weight=weights)

            loss.backward()
            optimizer.step()
            
            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'], 
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for inputdata in valloader:
                        images_val = inputdata['img_data']
                        labels_val = inputdata['seg_label']
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()


                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(),
                                             "{}_{}_best_model.pkl".format(
                                                 cfg['model']['arch'],
                                                 cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
Code Example #5
File: train.py Project: templeblock/TDNet
def train(cfg, logger, logdir):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader

    path_n = cfg["model"]["path_num"]

    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(data_path,split=cfg["data"]["train_split"],augmentations=t_data_aug,path_num=path_n)
    v_loader = data_loader(data_path,split=cfg["data"]["val_split"],augmentations=v_data_aug,path_num=path_n)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True  )
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"] )

    logger.info("Using training setting {}".format(cfg["training"]))
    
    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes)

    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    teacher = get_model(cfg["teacher"], t_loader.n_classes)
    model = get_model(cfg["model"],t_loader.n_classes, loss_fn, cfg["training"]["resume"],teacher)
    logger.info("Using loss {}".format(loss_fn))

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Setup Multi-GPU
    model = DataParallelModel(model).cuda()

    #Initialize training param
    cnt_iter = 0
    best_iou = 0.0
    time_meter = averageMeter()

    while cnt_iter <= cfg["training"]["train_iters"]:
        for (f_img, labels) in trainloader:
            cnt_iter += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            outputs = model(f_img,labels,pos_id=cnt_iter%path_n)

            seg_loss = gather(outputs, 0)
            seg_loss = torch.mean(seg_loss)

            seg_loss.backward()
            time_meter.update(time.time() - start_ts)

            optimizer.step()

            if (cnt_iter + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                                            cnt_iter + 1,
                                            cfg["training"]["train_iters"],
                                            seg_loss.item(),
                                            time_meter.avg / cfg["training"]["batch_size"], )

                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            if (cnt_iter + 1) % cfg["training"]["val_interval"] == 0 or (cnt_iter + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (f_img_val, labels_val) in tqdm(enumerate(valloader)):
                        
                        outputs = model(f_img_val,pos_id=i_val%path_n)
                        outputs = gather(outputs, 0, dim=0)
                        
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))

                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": cnt_iter + 1,
                        "model_state": clean_state_dict(model.module.state_dict(),'teacher'),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(logdir,
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
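Code Example #5 filters the saved state dict through clean_state_dict(..., 'teacher'), which is not shown. A plausible sketch, assuming it simply drops the frozen teacher's parameters before checkpointing:

def clean_state_dict(state_dict, drop_key):
    # Keep only entries whose key does not mention the teacher sub-module.
    return {k: v for k, v in state_dict.items() if drop_key not in k}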
Code Example #6
File: train.py Project: GibranBenitez/FASSD-Net
def train(cfg, writer, logger, args):
    # cfg

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(1024, 2048),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = FASSDNet(n_classes=19, alpha=args.alpha).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters:', total_params)
    model.apply(weights_init)

    # Non-strict ImageNet pre-train
    pretrained_path = 'weights/imagenet_weights.pth'
    checkpoint = torch.load(pretrained_path)
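    # `q` is just a one-shot flag for the (commented-out) debug prints in the loop below.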
    q = 1
    model_dict = {}
    state_dict = model.state_dict()

    # print('================== Weights orig: ', model.base[1].conv.weight[0][0][0])
    for k, v in checkpoint.items():
        if q == 1:
            # print("===> Key of checkpoint: ", k)
            # print("===> Value of checkpoint: ", v[0][0][0])
            if ('base.' + k in state_dict):
                # print("============> CONTAINS KEY...")
                # print("===> Value of the key: ", state_dict['base.'+k][0][0][0])
                pass

            else:
                # print("============> DOES NOT CONTAIN KEY...")
                pass
            q = 0

        if ('base.' + k in state_dict) and (state_dict['base.' + k].shape
                                            == checkpoint[k].shape):
            model_dict['base.' + k] = v

    state_dict.update(model_dict)  # Updated weights with ImageNet pretraining
    model.load_state_dict(state_dict)
    # print('================== Weights loaded: ', model.base[0].conv.weight[0][0][0])

    # Multi-gpu model
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    print("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    print("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):

            print_str = "Resuming model from '{}'".format(
                cfg["training"]["resume"])
            if logger is not None:
                logger.info(print_str)
            print(print_str)

            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]

            print_str = "Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"])
            print(print_str)
            if logger is not None:
                logger.info(print_str)
        else:
            print_str = "No checkpoint found at '{}'".format(
                cfg["training"]["resume"])
            print(print_str)
            if logger is not None:
                logger.info(print_str)

    if cfg["training"]["finetune"] is not None:
        if os.path.isfile(cfg["training"]["finetune"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["finetune"]))
            checkpoint = torch.load(cfg["training"]["finetune"])
            model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True
    loss_all = 0
    loss_n = 0
    sys.stdout.flush()

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, _) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)
            loss.backward()
            optimizer.step()
            c_lr = scheduler.get_lr()

            time_meter.update(time.time() - start_ts)
            loss_all += loss.item()
            loss_n += 1

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}  lr={:.6f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_all / loss_n,
                    time_meter.avg / cfg["training"]["batch_size"],
                    c_lr[0],
                )

                print(print_str)
                if logger is not None:
                    logger.info(print_str)

                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                torch.cuda.empty_cache()
                model.eval()
                loss_all = 0
                loss_n = 0
                with torch.no_grad():
                    # for i_val, (images_val, labels_val, _) in tqdm(enumerate(valloader)):
                    for i_val, (images_val, labels_val,
                                _) in enumerate(valloader):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)

                print_str = "Iter %d Val Loss: %.4f" % (i + 1,
                                                        val_loss_meter.avg)
                if logger is not None:
                    logger.info(print_str)
                print(print_str)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print_str = "{}: {}".format(k, v)
                    if logger is not None:
                        logger.info(print_str)
                    print(print_str)

                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    print_str = "{}: {}".format(k, v)
                    if logger is not None:
                        logger.info(print_str)
                    print(print_str)

                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"],
                                                  cfg["data"]["dataset"]),
                )
                torch.save(state, save_path)

                if score["Mean IoU : \t"] >= best_iou:  # Save best model (mIoU)
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                torch.cuda.empty_cache()

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
            sys.stdout.flush()  # Added
Code Example #7
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    if 'fold' not in cfg['data']:
        cfg['data']['fold'] = None

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        augmentations=data_aug,
        fold=cfg['data']['fold'],
        n_classes=cfg['data']['n_classes'])

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        fold=cfg['data']['fold'],
        n_classes=cfg['data']['n_classes'])

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=1,
                                num_workers=cfg['training']['n_workers'])

    logger.info("Training on fold {}".format(cfg['data']['fold']))
    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
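    # If a model_path other than the default is supplied, load its weights and
    # freeze the feature extractor; with cfg['model']['use_scale'] the weights
    # are merged via load_my_state_dict instead of a strict load_state_dict.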
    if args.model_path != "fcn8s_pascal_1_26.pkl": # Default Value
        state = convert_state_dict(torch.load(args.model_path)["model_state"])
        if cfg['model']['use_scale']:
            model = load_my_state_dict(model, state)
            model.freeze_weights_extractor()
        else:
            model.load_state_dict(state)
            model.freeze_weights_extractor()

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))


    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
#            import matplotlib.pyplot as plt
#            plt.figure(1);plt.imshow(np.transpose(images[0], (1,2,0)));plt.figure(2); plt.imshow(labels[0]); plt.show()

            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_best_model.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                torch.save(state, save_path)
                break
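Each of the train(cfg, writer, logger) variants in these examples expects the same three arguments: a cfg dict loaded from a YAML file, a tensorboardX SummaryWriter whose log directory doubles as the checkpoint directory, and a standard logging.Logger. A minimal driver sketch under those assumptions follows; the argparse flag, default config path, run-directory layout, and logging setup are illustrative and not taken from the snippets themselves.

import argparse
import logging
import os
import random

import yaml
from tensorboardX import SummaryWriter

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="config")
    parser.add_argument("--config", type=str, default="configs/fcn8s_pascal.yml",
                        help="path to the YAML configuration file")
    args = parser.parse_args()

    # cfg is the nested dict the train() functions index into (cfg['data'], ...)
    with open(args.config) as fp:
        cfg = yaml.safe_load(fp)

    # One run directory per invocation; the train() functions above save their
    # checkpoints into writer.file_writer.get_logdir(), i.e. this directory.
    run_id = random.randint(1, 100000)
    logdir = os.path.join("runs", os.path.basename(args.config)[:-4], str(run_id))
    writer = SummaryWriter(logdir)

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("train")

    train(cfg, writer, logger)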
Code example #8
def test(cfg, areaname):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    #    data_loader = get_loader(cfg['data']['dataset'])
    #    data_path = cfg['data']['path']
    #
    #    t_loader = data_loader(
    #        data_path,
    #        is_transform=True,
    #        split=cfg['data']['train_split'],
    #        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    #        augmentations=data_aug)
    #
    #    v_loader = data_loader(
    #        data_path,
    #        is_transform=True,
    #        split=cfg['data']['val_split'],
    #        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)
    #
    #    n_classes = t_loader.n_classes
    #    trainloader = data.DataLoader(t_loader,
    #                                  batch_size=cfg['training']['batch_size'],
    #                                  num_workers=cfg['training']['n_workers'],
    #                                  shuffle=True)
    #
    #    valloader = data.DataLoader(v_loader,
    #                                batch_size=cfg['training']['batch_size'],
    #                                num_workers=cfg['training']['n_workers'])
    datapath = '/home/chengjjang/Projects/deepres/SatelliteData/{}/'.format(
        areaname)
    paths = {
        'masks': '{}/patch{}_train/gt'.format(datapath, areaname),
        'images': '{}/patch{}_train/rgb'.format(datapath, areaname),
        'nirs': '{}/patch{}_train/nir'.format(datapath, areaname),
        'swirs': '{}/patch{}_train/swir'.format(datapath, areaname),
        'vhs': '{}/patch{}_train/vh'.format(datapath, areaname),
        'vvs': '{}/patch{}_train/vv'.format(datapath, areaname),
        'redes': '{}/patch{}_train/rede'.format(datapath, areaname),
        'ndvis': '{}/patch{}_train/ndvi'.format(datapath, areaname),
    }

    valpaths = {
        'masks': '{}/patch{}_val/gt'.format(datapath, areaname),
        'images': '{}/patch{}_val/rgb'.format(datapath, areaname),
        'nirs': '{}/patch{}_val/nir'.format(datapath, areaname),
        'swirs': '{}/patch{}_val/swir'.format(datapath, areaname),
        'vhs': '{}/patch{}_val/vh'.format(datapath, areaname),
        'vvs': '{}/patch{}_val/vv'.format(datapath, areaname),
        'redes': '{}/patch{}_val/rede'.format(datapath, areaname),
        'ndvis': '{}/patch{}_val/ndvi'.format(datapath, areaname),
    }

    n_classes = 3
    train_img_paths = [
        pth for pth in os.listdir(paths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    val_img_paths = [
        pth for pth in os.listdir(valpaths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    train_idx = list(range(ntrain))
    val_idx = list(range(nval))
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')

    print('valds.im_names: {}'.format(valds.im_names))

    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = '{}/patch{}_train'.format(datapath, areaname)
        dataset_path, train_dir = os.path.split(train_data_path)
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = ValDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'],
                                shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    nbackground = 1116403140
    ncorn = 44080178
    nsoybean = 316698122
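    # The pixel counts above (background / corn / soybean) are hard-coded for
    # this dataset and converted into normalized inverse-frequency class
    # weights below.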

    print('nbackground: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    print('nsoybean: {}'.format(nsoybean))

    wgts = [1.0, 1.0 * nbackground / ncorn, 1.0 * nbackground / nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0] / total_wgts
    wgt_corn = wgts[1] / total_wgts
    wgt_soybean = wgts[2] / total_wgts
    weights = torch.tensor([wgt_background, wgt_corn, wgt_soybean],
                           dtype=torch.float32, device=device)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)

    start_iter = 0
    runpath = '/home/chengjjang/arisia/CropPSPNet/runs/pspnet_crop_{}'.format(
        areaname)
    modelpath = glob.glob('{}/*/*_best_model.pkl'.format(runpath))[0]
    print('modelpath: {}'.format(modelpath))
    checkpoint = torch.load(modelpath)
    model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0

    respath = '{}_results_val'.format(areaname)
    os.makedirs(respath, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for inputdata in valloader:
            imname_val = inputdata['img_name']
            images_val = inputdata['img_data']
            labels_val = inputdata['seg_label']
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            print('imname_val: {}'.format(imname_val))

            outputs = model(images_val)
            val_loss = loss_fn(input=outputs, target=labels_val)

            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()

            dname = imname_val[0].split('.png')[0]
            np.save('{}/pred'.format(respath) + dname + '.npy', pred)
            np.save('{}/gt'.format(respath) + dname + '.npy', gt)
            np.save('{}/output'.format(respath) + dname + '.npy',
                    outputs.data.cpu().numpy())

            running_metrics_val.update(gt, pred)
            val_loss_meter.update(val_loss.item())

    #writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
    #logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
    print('Test loss: {}'.format(val_loss_meter.avg))

    score, class_iou = running_metrics_val.get_scores()
    for k, v in score.items():
        print('val_metrics, {}: {}'.format(k, v))

    for k, v in class_iou.items():
        print('val_metrics, {}: {}'.format(k, v))

    val_loss_meter.reset()
    running_metrics_val.reset()
Code example #9
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataloader_type"])

    data_root = cfg["data"]["data_root"]
    presentation_root = cfg["data"]["presentation_root"]

    t_loader = data_loader(
        data_root=data_root,
        presentation_root=presentation_root,
        is_transform=True,
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(data_root=data_root,
                           presentation_root=presentation_root,
                           is_transform=True,
                           img_size=(cfg["data"]["img_rows"],
                                     cfg["data"]["img_cols"]),
                           augmentations=data_aug,
                           test_mode=True)

    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=False,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"],
                                shuffle=False)

    # Setup Metrics
    # running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes, defaultParams).to(device)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    model.load_pretrained_weights(cfg["training"]["saved_model_path"])

    # train_loss_meter = averageMeter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter

    while i <= cfg["training"]["num_presentations"]:

        #                #
        # TRAINING PHASE #
        #                #
        i += 1
        start_ts = time.time()
        trainloader.dataset.random_select()

        hebb = model.initialZeroHebb().to(device)
        for idx, (images, labels) in enumerate(
                trainloader, 1):  # get a single training presentation
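            # Interpretation of the idx <= 5 split below: the first five
            # batches of each presentation run in eval mode without gradients
            # to build up the Hebbian trace; the remaining batches are trained
            # on with the accumulated trace.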

            images = images.to(device)
            labels = labels.to(device)

            if idx <= 5:
                model.eval()
                with torch.no_grad():
                    outputs, hebb = model(images,
                                          labels,
                                          hebb,
                                          device,
                                          test_mode=False)
            else:
                scheduler.step()
                model.train()
                optimizer.zero_grad()
                outputs, hebb = model(images,
                                      labels,
                                      hebb,
                                      device,
                                      test_mode=True)
                loss = loss_fn(input=outputs, target=labels)
                loss.backward()
                optimizer.step()

        time_meter.update(time.time() -
                          start_ts)  # -> time taken per presentation

        if (i + 1) % cfg["training"]["print_interval"] == 0:
            fmt_str = "Pres [{:d}/{:d}]  Loss: {:.4f}  Time/Pres: {:.4f}"
            print_str = fmt_str.format(
                i + 1,
                cfg["training"]["num_presentations"],
                loss.item(),
                time_meter.avg / cfg["training"]["batch_size"],
            )
            print(print_str)
            logger.info(print_str)
            writer.add_scalar("loss/test_loss", loss.item(), i + 1)
            time_meter.reset()

        #            #
        # TEST PHASE #
        #            #
        if ((i + 1) % cfg["training"]["test_interval"] == 0
                or (i + 1) == cfg["training"]["num_presentations"]):

            training_state_dict = model.state_dict()  # saving the training state of the model

            valloader.dataset.random_select()
            hebb = model.initialZeroHebb().to(device)
            for idx, (images_val, labels_val) in enumerate(
                    valloader, 1):  # get a single test presentation
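                # Same scheme as training: the first five batches prime the
                # Hebbian trace; the rest are adapted on and scored for the
                # validation metrics (the training weights are restored after).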

                images_val = images_val.to(device)
                labels_val = labels_val.to(device)

                if idx <= 5:
                    model.eval()
                    with torch.no_grad():
                        outputs, hebb = model(images_val,
                                              labels_val,
                                              hebb,
                                              device,
                                              test_mode=False)
                else:
                    model.train()
                    optimizer.zero_grad()
                    outputs, hebb = model(images_val,
                                          labels_val,
                                          hebb,
                                          device,
                                          test_mode=True)
                    loss = loss_fn(input=outputs, target=labels_val)
                    loss.backward()
                    optimizer.step()

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()

                    running_metrics_val.update(gt, pred)
                    val_loss_meter.update(loss.item())

            model.load_state_dict(
                training_state_dict)  # revert back to training parameters

            writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
            logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

            for k, v in class_iou.items():
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

            val_loss_meter.reset()
            running_metrics_val.reset()

            if score["Mean IoU : \t"] >= best_iou:
                best_iou = score["Mean IoU : \t"]
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_best_model.pkl".format(
                        cfg["model"]["arch"], cfg["data"]["dataloader_type"]),
                )
                torch.save(state, save_path)

        if (i + 1) == cfg["training"]["num_presentations"]:
            break
Code example #10
def train(cfg, writer, logger):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])

    t_loader = data_loader(cfg=cfg["data"],
                           mode='train',
                           augmentations=t_data_aug)
    v_loader = data_loader(cfg=cfg["data"],
                           mode='val',
                           augmentations=v_data_aug)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])

    logger.info("Using training seting {}".format(cfg["training"]))

    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes,
                                       t_loader.unseen_classes)

    model_state = torch.load(
        './runs/deeplabv3p_ade_25unseen/84253/deeplabv3p_ade20k_best_model.pkl'
    )
    running_metrics_val.confusion_matrix = model_state['results']
    score, a_iou = running_metrics_val.get_scores()

    pdb.set_trace()
    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    logger.info("Using loss {}".format(loss_fn))
    model = get_model(cfg["model"], t_loader.n_classes, loss_fn=loss_fn)

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Initialize training param
    start_iter = 0
    best_iou = -100.0

    # Resume from checkpoint
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info("Resuming training from checkpoint '{}'".format(
                cfg["training"]["resume"]))
            model_state = torch.load(cfg["training"]["resume"])["model_state"]
            model.load_state_dict(model_state)
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # Setup Multi-GPU
    if torch.cuda.is_available():
        model = model.cuda()  # DataParallelModel(model).cuda()
        logger.info("Model initialized on GPUs.")

    time_meter = averageMeter()
    i = start_iter

    embd = t_loader.embeddings
    ignr_idx = t_loader.ignore_index
    embds = embd.cuda()
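    # Class embeddings and the ignore index come from the loader and are
    # passed to the model together with the images and labels.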
    while i <= cfg["training"]["train_iters"]:
        for (images, labels) in trainloader:
            images = images.cuda()
            labels = labels.cuda()

            i += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            loss_sum = model(images, labels, embds, ignr_idx)
            if loss_sum == 0:  # Ignore samples that contain unseen categories
                continue  # To enable non-transductive learning, set transductive=0 in the config

            loss_sum.backward()

            time_meter.update(time.time() - start_ts)

            optimizer.step()

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_sum.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss_sum.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.cuda()
                        labels_val = labels_val.cuda()
                        outputs = model(images_val, labels_val, embds,
                                        ignr_idx)
                        # outputs = gather(outputs, 0, dim=0)

                        running_metrics_val.update(outputs)

                score, a_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print("{}: {}".format(k, v))
                    logger.info("{}: {}".format(k, v))
                    #writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                #for k, v in class_iou.items():
                #    logger.info("{}: {}".format(k, v))
                #    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                if a_iou >= best_iou:
                    best_iou = a_iou
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                        "results": running_metrics_val.confusion_matrix
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                running_metrics_val.reset()
Code example #11
File: train.py  Project: templeblock/recurrent-unet
def train(cfg, writer, logger, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device('cuda')

    # Setup Augmentations
    # augmentations = cfg['training'].get('augmentations', None)
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get(
            'augmentations', {
                'brightness': 63. / 255.,
                'saturation': 0.5,
                'contrast': 0.8,
                'hflip': 0.5,
                'rotate': 10,
                'rscalecropsquare': 713,
            })
        # augmentations = cfg['training'].get('augmentations',
        #                                     {'rotate': 10, 'hflip': 0.5, 'rscalecrop': 512, 'gaussian': 0.5})
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes, args).to(device)
    model.apply(weights_init)
    print('sleep for 5 seconds')
    time.sleep(5)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    # model = torch.nn.DataParallel(model, device_ids=(0, 1))
    print(model.device_ids)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    if 'multi_step' in cfg['training']['loss']['name']:
        my_loss_fn = loss_fn(
            scale_weight=cfg['training']['loss']['scale_weight'],
            n_inp=2,
            weight=None,
            reduction='sum',
            bkargs=args)
    else:
        my_loss_fn = loss_fn(weight=None, reduction='sum', bkargs=args)

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = my_loss_fn(myinput=outputs, target=labels)

            loss.backward()
            optimizer.step()

            # gpu_profile(frame=sys._getframe(), event='line', arg=None)

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = my_loss_fn(myinput=outputs,
                                              target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
Code example #12
def train(cfg, writer, logger):

    # Setup seeds for reproducing
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'], cfg['task'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        img_norm=cfg['data']['img_norm'],
        # version = cfg['data']['version'],
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_norm=cfg['data']['img_norm'],
        # version=cfg['data']['version'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    if cfg['task'] == "seg":
        n_classes = t_loader.n_classes
        running_metrics_val = runningScoreSeg(n_classes)
    elif cfg['task'] == "depth":
        n_classes = 0
        running_metrics_val = runningScoreDepth()
    else:
        raise NotImplementedError('Task {} not implemented'.format(
            cfg['task']))

    # Setup Model
    model = get_model(cfg['model'], cfg['task'], n_classes).to(device)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            # checkpoint = torch.load(cfg['training']['resume'], map_location=lambda storage, loc: storage)  # load model trained on gpu on cpu
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    best_rel = 100.0
    # i = start_iter
    i = 0
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        print(len(trainloader))
        for (images, labels, img_path) in trainloader:
            start_ts = time.time()  # return current time stamp
            scheduler.step()
            model.train()  # set model to training mode
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()  #clear earlier gradients
            outputs = model(images)
            if cfg['model']['arch'] == "dispnet" and cfg['task'] == "depth":
                outputs = 1 / outputs

            loss = loss_fn(input=outputs, target=labels)  # compute loss
            loss.backward()  # backpropagation loss
            optimizer.step()  # optimizer parameter update

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.val / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or (
                    i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                img_path_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(
                            images_val
                        )  # [batch_size, n_classes, height, width]
                        if cfg['model']['arch'] == "dispnet" and cfg[
                                'task'] == "depth":
                            outputs = 1 / outputs

                        val_loss = loss_fn(input=outputs, target=labels_val
                                           )  # mean pixelwise loss in a batch

                        if cfg['task'] == "seg":
                            pred = outputs.data.max(1)[1].cpu().numpy(
                            )  # [batch_size, height, width]
                            gt = labels_val.data.cpu().numpy(
                            )  # [batch_size, height, width]
                        elif cfg['task'] == "depth":
                            pred = outputs.squeeze(1).data.cpu().numpy()
                            gt = labels_val.data.squeeze(1).cpu().numpy()
                        else:
                            raise NotImplementedError(
                                'Task {} not implemented'.format(cfg['task']))

                        running_metrics_val.update(gt=gt, pred=pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d val_loss: %.4f" %
                            (i + 1, val_loss_meter.avg))
                print("Iter %d val_loss: %.4f" % (i + 1, val_loss_meter.avg))

                # output scores
                if cfg['task'] == "seg":
                    score, class_iou = running_metrics_val.get_scores()
                    for k, v in score.items():
                        print(k, v)
                        sys.stdout.flush()
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)
                    for k, v in class_iou.items():
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                          i + 1)

                elif cfg['task'] == "depth":
                    val_result = running_metrics_val.get_scores()
                    for k, v in val_result.items():
                        print(k, v)
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)
                else:
                    raise NotImplementedError('Task {} not implemented'.format(
                        cfg['task']))

                val_loss_meter.reset()
                running_metrics_val.reset()

                save_model = False
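                # Checkpoint selection: segmentation keeps the highest mean
                # IoU, depth keeps the lowest absolute relative error.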
                if cfg['task'] == "seg":
                    if score["Mean IoU : \t"] >= best_iou:
                        best_iou = score["Mean IoU : \t"]
                        save_model = True
                        state = {
                            "epoch": i + 1,
                            "model_state": model.state_dict(),
                            "optimizer_state": optimizer.state_dict(),
                            "scheduler_state": scheduler.state_dict(),
                            "best_iou": best_iou,
                        }

                if cfg['task'] == "depth":
                    if val_result["abs rel : \t"] <= best_rel:
                        best_rel = val_result["abs rel : \t"]
                        save_model = True
                        state = {
                            "epoch": i + 1,
                            "model_state": model.state_dict(),
                            "optimizer_state": optimizer.state_dict(),
                            "scheduler_state": scheduler.state_dict(),
                            "best_rel": best_rel,
                        }

                if save_model:
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
            i += 1
Code example #13
def train(cfg, writer, logger, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', RNG_SEED))
    torch.cuda.manual_seed(cfg.get('seed', RNG_SEED))
    np.random.seed(cfg.get('seed', RNG_SEED))
    random.seed(cfg.get('seed', RNG_SEED))

    # Setup device
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(args.device)

    # Setup Augmentations
    # augmentations = cfg['training'].get('augmentations', None)
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 10,
                                             'rscalecropsquare': 704,  # 640, # 672, # 704,
                                             })
    elif cfg['data']['dataset'] in ['drive']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 180,
                                             'rscalecropsquare': 576,
                                             })
        # augmentations = cfg['training'].get('augmentations',
        #                                     {'rotate': 10, 'hflip': 0.5, 'rscalecrop': 512, 'gaussian': 0.5})
    else:
        augmentations = cfg['training'].get('augmentations', {'rotate': 10, 'hflip': 0.5})
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes, cfg['data']['void_class'] > 0)

    # Setup Model
    print('trying device {}'.format(device))
    model = get_model(cfg['model'], n_classes, args)  # .to(device)

    if cfg['model']['arch'] not in ['unetvgg16', 'unetvgg16gn', 'druvgg16', 'unetresnet50', 'unetresnet50bn',
                                    'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
        model.apply(weights_init)
    else:
        init_model(model)

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # if cfg['model']['arch'] in ['druresnet50syncedbn']:
    #     print('using synchronized batch normalization')
    #     time.sleep(5)
    #     patch_replication_callback(model)

    model = model.cuda()
    # model = torch.nn.DataParallel(model, device_ids=(3, 2))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}
    if cfg['model']['arch'] in ['unetvgg16', 'unetvgg16gn', 'druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
        optimizer = optimizer_cls([
            {'params': model.module.paramGroup1.parameters(), 'lr': optimizer_params['lr'] / 10},
            {'params': model.module.paramGroup2.parameters()}
        ], **optimizer_params)
    else:
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.warning(f"Model parameters in total: {sum([p.numel() for p in model.parameters()])}")
    logger.warning(f"Trainable parameters in total: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    weight = torch.ones(n_classes)
    if cfg['data'].get('void_class'):
        if cfg['data'].get('void_class') >= 0:
            weight[cfg['data'].get('void_class')] = 0.
    weight = weight.to(device)

    logger.info("Set the prediction weights as {}".format(weight))

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            # for param_group in optimizer.param_groups:
            #     print(param_group['lr'])
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
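            # Recurrent variants need an initial hidden state h0 (and, for the
            # DRU/SRU variants, an initial segmentation map s0); its spatial
            # size depends on how far the encoder downsamples the input.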
            if cfg['model']['arch'] in ['reclast']:
                h0 = torch.ones([images.shape[0], args.hidden_size, images.shape[2], images.shape[3]],
                                dtype=torch.float32)
                h0 = h0.to(device)  # Tensor.to() is not in-place; reassign so h0 actually moves to the device
                outputs = model(images, h0)

            elif cfg['model']['arch'] in ['recmid']:
                W, H = images.shape[2], images.shape[3]
                w = int(np.floor(np.floor(np.floor(W/2)/2)/2)/2)
                h = int(np.floor(np.floor(np.floor(H/2)/2)/2)/2)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                                dtype=torch.float32)
                h0 = h0.to(device)
                outputs = model(images, h0)

            elif cfg['model']['arch'] in ['dru', 'sru']:
                W, H = images.shape[2], images.shape[3]
                w = int(np.floor(np.floor(np.floor(W/2)/2)/2)/2)
                h = int(np.floor(np.floor(np.floor(H/2)/2)/2)/2)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                                dtype=torch.float32)
                h0 = h0.to(device)
                s0 = torch.ones([images.shape[0], n_classes, W, H],
                                dtype=torch.float32)
                s0 = s0.to(device)
                outputs = model(images, h0, s0)

            elif cfg['model']['arch'] in ['druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                W, H = images.shape[2], images.shape[3]
                w, h = int(W / 2 ** 4), int(H / 2 ** 4)
                if cfg['model']['arch'] in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                    w, h = int(W / 2 ** 5), int(H / 2 ** 5)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                                dtype=torch.float32, device=device)
                s0 = torch.zeros([images.shape[0], n_classes, W, H],
                                 dtype=torch.float32, device=device)
                outputs = model(images, h0, s0)

            else:
                outputs = model(images)

            loss = loss_fn(input=outputs, target=labels, weight=weight, bkargs=args)
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            # if use_grad_clip(cfg['model']['arch']):  #
            # if cfg['model']['arch'] in ['rcnn', 'rcnn2', 'rcnn3']:  #
            if use_grad_clip(cfg['model']['arch']):
                nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'], 
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                # print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                torch.backends.cudnn.benchmark = False
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        if args.benchmark:
                            if i_val > 10:
                                break
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        if cfg['model']['arch'] in ['reclast']:
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, images_val.shape[2], images_val.shape[3]],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            outputs = model(images_val, h0)

                        elif cfg['model']['arch'] in ['recmid']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                            h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            outputs = model(images_val, h0)

                        elif cfg['model']['arch'] in ['dru', 'sru']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                            h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            s0 = torch.ones([images_val.shape[0], n_classes, W, H],
                                            dtype=torch.float32)
                            s0 = s0.to(device)
                            outputs = model(images_val, h0, s0)

                        elif cfg['model']['arch'] in ['druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w, h = int(W / 2**4), int(H / 2**4)
                            if cfg['model']['arch'] in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                                w, h = int(W / 2 ** 5), int(H / 2 ** 5)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            s0 = torch.zeros([images_val.shape[0], n_classes, W, H],
                                             dtype=torch.float32)
                            s0 = s0.to(device)
                            outputs = model(images_val, h0, s0)

                        else:
                            outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val, bkargs=args)

                        if cfg['training']['loss']['name'] in ['multi_step_cross_entropy']:
                            pred = outputs[-1].data.max(1)[1].cpu().numpy()
                        else:
                            pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        logger.debug('pred shape: %s \t ground-truth shape: %s', pred.shape, gt.shape)
                        # IPython.embed()
                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())
                    # assert i_val > 0, "Validation dataset is empty for no reason."
                torch.backends.cudnn.benchmark = True
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                # IPython.embed()
                score, class_iou, _ = running_metrics_val.get_scores()
                for k, v in score.items():
                    # print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(),
                                             best_model_path(cfg))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_final_model.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                torch.save(state, save_path)
                break
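The helper use_grad_clip referenced in the loop above is not part of this listing; judging from the commented-out lines, it simply gates nn.utils.clip_grad_norm_ on the architecture name. A minimal sketch under that assumption (the set of architecture names is illustrative, not taken from the project):

# Hypothetical helper, assuming clipping is only wanted for recurrent architectures.
RECURRENT_ARCHS = {'rcnn', 'rcnn2', 'rcnn3', 'reclast', 'recmid', 'dru', 'sru'}

def use_grad_clip(arch):
    """Return True if gradients for this architecture should be clipped."""
    return arch in RECURRENT_ARCHS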
Code example #14
def validate(cfg, args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.out_dir != "":
        for sub_dir in ['', 'hmaps_bg', 'hmaps_fg', 'pred', 'gt',
                        'qry_images', 'sprt_images', 'sprt_gt']:
            if not os.path.exists(args.out_dir + sub_dir):
                os.mkdir(args.out_dir + sub_dir)

    if args.fold != -1:
        cfg['data']['fold'] = args.fold

    fold = cfg['data']['fold']

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    loader = data_loader(
        data_path,
        split=cfg['data']['val_split'],
        is_transform=True,
        img_size=[cfg['data']['img_rows'],
                  cfg['data']['img_cols']],
        n_classes=cfg['data']['n_classes'],
        fold=cfg['data']['fold'],
        binary=args.binary,
        k_shot=cfg['data']['k_shot']
    )

    n_classes = loader.n_classes

    valloader = data.DataLoader(loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=0)
    if args.binary:
        running_metrics = runningScore(2)
        fp_list = {}
        tp_list = {}
        fn_list = {}
    else:
        running_metrics = runningScore(n_classes + 1)  # +1 for the novel class that is added each time

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.to(device)
    model.freeze_all_except_classifiers()

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}
    model.save_original_weights()

    alpha = 0.14139
    for i, (sprt_images, sprt_labels, qry_images, qry_labels,
            original_sprt_images, original_qry_images, cls_ind) in enumerate(valloader):

        cls_ind = int(cls_ind)
        print('Starting iteration ', i)
        start_time = timeit.default_timer()
        if args.out_dir != "":
            save_images(original_sprt_images, sprt_labels,
                        original_qry_images, i, args.out_dir)

        for si in range(len(sprt_images)):
            sprt_images[si] = sprt_images[si].to(device)
            sprt_labels[si] = sprt_labels[si].to(device)
        qry_images = qry_images.to(device)

        # 1- Extract embedding and add the imprinted weights
        if args.iterations_imp > 0:
            model.iterative_imprinting(sprt_images, qry_images, sprt_labels,
                                       alpha=alpha, itr=args.iterations_imp)
        else:
            model.imprint(sprt_images, sprt_labels, alpha=alpha, random=args.rand)

        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
        scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])
        loss_fn = get_loss_function(cfg)
        print('Finetuning')
        for j in range(cfg['training']['train_iters']):
            for b in range(len(sprt_images)):
                torch.cuda.empty_cache()
                model.train()
                optimizer.zero_grad()

                outputs = model(sprt_images[b])
                loss = loss_fn(input=outputs, target=sprt_labels[b])
                loss.backward()
                optimizer.step()
                # In PyTorch 1.1.0+, optimizer.step() should be called before lr_scheduler.step()
                scheduler.step()

                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}"
                print_str = fmt_str.format(j,
                                       cfg['training']['train_iters'],
                                       loss.item())
                print(print_str)


        # 2- Infer on the query image
        model.eval()
        with torch.no_grad():
            outputs = model(qry_images)
            pred = outputs.data.max(1)[1].cpu().numpy()

        # Reverse the last imprinting (Few shot setting only not Continual Learning setup yet)
        model.reverse_imprinting()

        gt = qry_labels.numpy()
        if args.binary:
            gt, pred = post_process(gt, pred)
            if args.binary == 1:
                tp, fp, fn = running_metrics.update_binary_oslsm(gt, pred)

                fp_list[cls_ind] = fp_list.get(cls_ind, 0) + fp
                tp_list[cls_ind] = tp_list.get(cls_ind, 0) + tp
                fn_list[cls_ind] = fn_list.get(cls_ind, 0) + fn
            else:
                running_metrics.update(gt, pred)
        else:
            running_metrics.update(gt, pred)

        if args.out_dir != "":
            if args.binary:
                save_vis(outputs, pred, gt, i, args.out_dir, fg_class=1)
            else:
                save_vis(outputs, pred, gt, i, args.out_dir)

    if args.binary:
        if args.binary == 1:
            iou_list = [tp_list[ic]/float(max(tp_list[ic] + fp_list[ic] + fn_list[ic],1)) \
                         for ic in tp_list.keys()]
            print("Binary Mean IoU ", np.mean(iou_list))
        else:
            score, class_iou = running_metrics.get_scores()
            for k, v in score.items():
                print(k, v)
    else:
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        val_nclasses = model.n_classes + 1
        for i in range(val_nclasses):
            print(i, class_iou[i])
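The binary score printed above is the usual IoU computed from the accumulated per-class pixel counts, IoU = TP / (TP + FP + FN), averaged over the classes encountered. A small self-contained sketch of that aggregation (the counts below are made up for illustration):

import numpy as np

def mean_binary_iou(tp, fp, fn):
    """Per-class IoU = TP / (TP + FP + FN), averaged over all classes seen."""
    ious = [tp[c] / float(max(tp[c] + fp[c] + fn[c], 1)) for c in tp]
    return np.mean(ious)

# two hypothetical novel classes
print(mean_binary_iou({1: 80, 2: 50}, {1: 10, 2: 20}, {1: 10, 2: 30}))  # 0.65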
Code example #15
File: train_polar.py  Project: yoyoyoohh/PolSAR_CD
def train(cfg, writer, logger):
    
    # Setup random seeds to a fixed value for reproducibility
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format = cfg.data.format,
        split=cfg.data.train_split,
        norm = cfg.data.norm,
        augments=data_aug
        )

    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format = cfg.data.format,
        split=cfg.data.val_split,
        )
    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size, 
                                  num_workers=cfg.train.n_workers, 
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader, 
                                batch_size=10, 
                                # persistent_workers=True,
                                num_workers=cfg.train.n_workers,)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; this works well
    
    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume)
            )

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg.train.resume, checkpoint["epoch"]
                )
            )

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir, )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time() 
    train_val_start_time = time.time()
    model.train()   
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1           
            file_a = file_a.to(device)            
            file_b = file_b.to(device)            
            label = label.to(device)            
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()
            
            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            running_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = running_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(it,
                                           train_iter,
                                           loss.item(),      #extracts the loss’s value as a Python float.
                                           train_time_meter.avg / cfg.train.batch_size,train_cls_0_acc, train_cls_1_acc)
                running_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {'cls_0':train_cls_0_acc, 'cls_1':train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()            # change behavior like drop out
                with torch.no_grad():   # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:      
                        file_a_val = file_a_val.to(device)            
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())
            
                        label_val = label_val.to(device)            
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)

                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val', {'cls_0':val_cls_0_acc, 'cls_1':val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # OA=score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc+val_cls_1_acc)/2
                if val_macro_OA >= best_macro_OA_now and it>200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now':best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch,cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter-it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time() 

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    state = {
            "epoch": it,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_macro_OA_now": best_macro_OA_now,
            'best_macro_OA_iter_now':best_macro_OA_iter_now,
            }
    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
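averageMeter and runningScore in these listings come from the surrounding project's metrics module and are not shown here. A minimal running-average meter with the same update()/avg/reset() interface, assuming the usual semantics, would be:

class AverageMeter:
    """Tracks a running average through update(value) / .avg / reset()."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count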
Code example #16
def train(cfg, writer, logger, run_id):

    # Setup random seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    torch.backends.cudnn.benchmark = True

    # Setup Augmentations
    augmentations = cfg['train'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataloader'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           transform=None,
                           split=cfg['data']['train_split'],
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        transform=None,
        split=cfg['data']['val_split'],
    )
    logger.info(
        f'num of train samples: {len(t_loader)} \nnum of val samples: {len(v_loader)}'
    )

    train_data_len = len(t_loader)
    batch_size = cfg['train']['batch_size']
    epoch = cfg['train']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')
    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['train']['batch_size'],
                                  num_workers=cfg['train']['n_workers'],
                                  shuffle=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['train']['batch_size'],
                                num_workers=cfg['train']['n_workers'])

    # Setup Model
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))
    device = f'cuda:{cuda_idx[0]}'
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=cuda_idx)  # automatic multi-GPU execution; this works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['train']['optimizer'].items() if k != 'name'
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    scheduler = get_scheduler(optimizer, cfg['train']['lr_schedule'])
    loss_fn = get_loss_function(cfg)
    # logger.info("Using loss {}".format(loss_fn))

    # set checkpoints
    start_iter = 0
    if cfg['train']['resume'] is not None:
        if os.path.isfile(cfg['train']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['train']['resume']))
            checkpoint = torch.load(cfg['train']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['train']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['train']['resume']))

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = 0
    flag = True

    val_rlt_f1 = []
    val_rlt_OA = []
    best_f1_till_now = 0
    best_OA_till_now = 0
    best_fwIoU_now = 0
    best_fwIoU_iter_till_now = 0

    # train
    it = start_iter
    model.train()
    while it <= train_iter and flag:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1
            start_ts = time.time()
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            outputs = model(file_a, file_b)

            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()
            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            train_time_meter.update(time.time() - start_ts)
            time_meter_val.update(time.time() - start_ts)

            if (it + 1) % cfg['train']['print_interval'] == 0:
                fmt_str = "train:\nIter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    it + 1,
                    train_iter,
                    loss.item(),  #extracts the loss’s value as a Python float.
                    train_time_meter.avg / cfg['train']['batch_size'])
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it + 1)

            if (it + 1) % cfg['train']['val_interval'] == 0 or \
               (it + 1) == train_iter:
                model.eval()  # change behavior like drop out
                with torch.no_grad():  # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val,
                         mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max will return the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        gt = label_val.numpy()
                        running_metrics_val.update(gt, pred, mask_val)

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs,
                                           target=label_val,
                                           mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                lr_now = optimizer.param_groups[0]['lr']
                logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it + 1)
                logger.info("Iter %d, val Loss: %.4f" %
                            (it + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                # for k, v in score.items():
                #     logger.info('{}: {}'.format(k, v))
                #     writer.add_scalar('val_metrics/{}'.format(k), v, it+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                      it + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                avg_f1 = score["Mean_F1"]
                OA = score["Overall_Acc"]
                fw_IoU = score["FreqW_IoU"]
                val_rlt_f1.append(avg_f1)
                val_rlt_OA.append(OA)

                if fw_IoU >= best_fwIoU_now and it > 200:
                    best_fwIoU_now = fw_IoU
                    correspond_meanIou = score["Mean_IoU"]
                    best_fwIoU_iter_till_now = it + 1

                    state = {
                        "epoch": it + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_fwIoU": best_fwIoU_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(
                            cfg['model']['arch'], cfg['data']['dataloader']))
                    torch.save(state, save_path)

                    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
                    logger.info("Best fwIoU Iter till now= %d" %
                                (best_fwIoU_iter_till_now))

                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - it)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                print(train_time)

            model.train()
            if (it + 1) == train_iter:
                flag = False
                logger.info("Use the Sar_seg_band3,val_interval: 30")
                break
    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
    logger.info("Best fwIoU Iter till now= %d" % (best_fwIoU_iter_till_now))

    state = {
        "epoch": it + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_fwIoU": best_fwIoU_now,
    }
    save_path = os.path.join(
        writer.file_writer.get_logdir(),
        "{}_{}_last_model.pkl".format(cfg['model']['arch'],
                                      cfg['data']['dataloader']))
    torch.save(state, save_path)
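Both PolSAR examples read two per-class accuracies out of score['Acc'] from a runningScore(2) instance. Assuming the usual confusion-matrix implementation, per-class accuracy is the matrix diagonal divided by the row sums, e.g.:

import numpy as np

def per_class_accuracy(conf_matrix):
    """Diagonal over row sums of a confusion matrix (rows = ground truth)."""
    return np.diag(conf_matrix) / np.maximum(conf_matrix.sum(axis=1), 1)

# 2x2 example: 90/100 correct for class 0, 30/50 correct for class 1
print(per_class_accuracy(np.array([[90, 10], [20, 30]])))  # [0.9 0.6]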
Code example #17
File: train.py  Project: hyzcn/MultiAgentPerception
        # import pdb; pdb.set_trace()

        # Setup optimizer
        optimizer_cls = get_optimizer(cfg)
        optimizer_params = {
            k: v
            for k, v in cfg["training"]["optimizer"].items() if k != "name"
        }
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
        logger.info("Using optimizer {}".format(optimizer))

        # Setup scheduler
        scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

        # Setup loss
        loss_fn = get_loss_function(cfg)
        logger.info("Using loss {}".format(loss_fn))

        # ================== TRAINING ==================
        if cfg['model']['arch'] == 'LearnWhen2Com':  # Our when2com
            trainer = Trainer_LearnWhen2Com(cfg, writer, logger, model,
                                            loss_fn, trainloader, valloader,
                                            optimizer, scheduler, device)
        elif cfg['model']['arch'] == 'LearnWho2Com':  # Our who2com
            trainer = Trainer_LearnWho2Com(cfg, writer, logger, model, loss_fn,
                                           trainloader, valloader, optimizer,
                                           scheduler, device)
        elif cfg['model']['arch'] == 'MIMOcom':  #
            trainer = Trainer_MIMOcom(cfg, writer, logger, model, loss_fn,
                                      trainloader, valloader, optimizer,
                                      scheduler, device)
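The if/elif chain above picks a trainer class by architecture name. Since all three trainers are constructed with the same arguments, an equivalent table-driven dispatch (a sketch, assuming the trainer classes from this listing are importable) is:

TRAINERS = {
    'LearnWhen2Com': Trainer_LearnWhen2Com,
    'LearnWho2Com': Trainer_LearnWho2Com,
    'MIMOcom': Trainer_MIMOcom,
}
trainer = TRAINERS[cfg['model']['arch']](cfg, writer, logger, model, loss_fn,
                                         trainloader, valloader, optimizer,
                                         scheduler, device)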
Code example #18
File: train.py  Project: llfl/MultiDepth
def train(cfg, writer, logger):

    # Setup random seeds
    torch.manual_seed(cfg.get('seed', 1860))
    torch.cuda.manual_seed(cfg.get('seed', 1860))
    np.random.seed(cfg.get('seed', 1860))
    random.seed(cfg.get('seed', 1860))

    # Setup device
    if cfg["device"]["use_gpu"]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if not torch.cuda.is_available():
            logger.warning("CUDA not available, using CPU instead!")
    else:
        device = torch.device("cpu")

    # Setup augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)
    if "rcrop" in augmentations.keys():
        data_aug_val = get_composed_augmentations(
            {"rcrop": augmentations["rcrop"]})

    # Setup dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']
    if 'depth_scaling' not in cfg['data'].keys():
        cfg['data']['depth_scaling'] = None
    if 'max_depth' not in cfg['data'].keys():
        logger.warning(
            "Key d_max not found in configuration file! Using default value")
        cfg['data']['max_depth'] = 256
    if 'min_depth' not in cfg['data'].keys():
        logger.warning(
            "Key d_min not found in configuration file! Using default value")
        cfg['data']['min_depth'] = 1
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug,
                           depth_scaling=cfg['data']['depth_scaling'],
                           n_bins=cfg['data']['depth_bins'],
                           max_depth=cfg['data']['max_depth'],
                           min_depth=cfg['data']['min_depth'])

    v_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['val_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug_val,
                           depth_scaling=cfg['data']['depth_scaling'],
                           n_bins=cfg['data']['depth_bins'],
                           max_depth=cfg['data']['max_depth'],
                           min_depth=cfg['data']['min_depth'])

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['validation']['batch_size'],
                                num_workers=cfg['validation']['n_workers'],
                                shuffle=True,
                                drop_last=True)

    # Check selected tasks
    if sum(cfg["data"]["tasks"].values()) > 1:
        logger.info("Running multi-task training with config: {}".format(
            cfg["data"]["tasks"]))

    # Get output dimension of the network's final layer
    n_classes_d_cls = None
    if cfg["data"]["tasks"]["d_cls"]:
        n_classes_d_cls = t_loader.n_classes_d_cls

    # Setup metrics for validation
    if cfg["data"]["tasks"]["d_cls"]:
        running_metrics_val_d_cls = runningScore(n_classes_d_cls)
    if cfg["data"]["tasks"]["d_reg"]:
        running_metrics_val_d_reg = running_side_score()

    # Setup model
    model = get_model(cfg['model'],
                      cfg["data"]["tasks"],
                      n_classes_d_cls=n_classes_d_cls).to(device)
    # model = d_regResNet().to(device)

    # Setup multi-GPU support
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        logger.info("Running multi-gpu training on {} GPUs".format(n_gpus))
        model = torch.nn.DataParallel(model, device_ids=range(n_gpus))

    # Setup multi-task loss
    task_weights = {}
    update_weights = cfg["training"]["task_weight_policy"] == 'update'
    for task, weight in cfg["training"]["task_weight_init"].items():
        task_weights[task] = torch.tensor(weight).float()
        task_weights[task] = task_weights[task].to(device)
        task_weights[task] = task_weights[task].requires_grad_(update_weights)
    logger.info("Task weights were initialized with {}".format(
        cfg["training"]["task_weight_init"]))

    # Setup optimizer and lr_scheduler
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    objective_params = list(model.parameters()) + list(task_weights.values())
    optimizer = optimizer_cls(objective_params, **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])
    logger.info("Using learning-rate scheduler {}".format(scheduler))

    # Setup task-specific loss functions
    # logger.debug("setting loss functions")
    loss_fns = {}
    for task, selected in cfg["data"]["tasks"].items():
        if selected:
            logger.info("Task " + task + " was selected for training.")
            loss_fn = get_loss_function(cfg, task)
            logger.info("Using loss function {} for task {}".format(
                loss_fn, task))
            loss_fns[task] = loss_fn

    # Load weights from old checkpoint if set
    # logger.debug("checking for resume checkpoint")
    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            logger.info("Loading file...")
            checkpoint = torch.load(cfg['training']['resume'],
                                    map_location="cpu")
            logger.info("Loading model...")
            model.load_state_dict(checkpoint["model_state"])
            model.to("cpu")
            model.to(device)
            logger.info("Restoring task weights...")
            task_weights = checkpoint["task_weights"]
            for task, state in task_weights.items():
                # task_weights[task] = state.to(device)
                task_weights[task] = torch.tensor(state.data).float()
                task_weights[task] = task_weights[task].to(device)
                task_weights[task] = task_weights[task].requires_grad_(
                    update_weights)
            logger.info("Loading scheduler...")
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            #            scheduler.to("cpu")
            start_iter = checkpoint["iteration"]

            # Add loaded parameters to optimizer
            # NOTE task_weights will not update otherwise!
            logger.info("Loading optimizer...")
            optimizer_cls = get_optimizer(cfg)
            objective_params = list(model.parameters()) + \
                list(task_weights.values())
            optimizer = optimizer_cls(objective_params, **optimizer_params)
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # for state in optimizer.state.values():
            #     for k, v in state.items():
            #         if torch.is_tensor(v):
            #             state[k] = v.to(device)

            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["iteration"]))
        else:
            logger.error(
                "No checkpoint found at '{}'. Re-initializing params!".format(
                    cfg['training']['resume']))

    # Initialize meters for various metrics
    # logger.debug("initializing metrics")
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # Setup other utility variables
    i = start_iter
    flag = True
    timer_training_start = time.time()

    logger.info("Starting training phase...")

    logger.debug("model device cuda?")
    logger.debug(next(model.parameters()).is_cuda)
    logger.debug("d_reg weight device:")
    logger.debug(task_weights["d_reg"].device)
    logger.debug("cls weight device:")
    logger.debug(task_weights["d_cls"].device)

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:

            start_ts = time.time()
            scheduler.step()
            model.train()

            # Forward pass
            # logger.debug("sending images to device")
            images = images.to(device)
            optimizer.zero_grad()
            # logger.debug("forward pass")
            outputs = model(images)

            # Clip predicted depth to min/max
            # logger.debug("clamping outputs")
            if cfg["data"]["tasks"]["d_reg"]:
                if cfg["data"]["depth_scaling"] is not None:
                    if cfg["data"]["depth_scaling"] == "clip":
                        logger.warning("Using deprecated clip function!")
                        outputs["d_reg"] = torch.clamp(
                            outputs["d_reg"], 0, cfg["data"]["max_depth"])

            # Calculate single-task losses
            # logger.debug("calculate loss")
            st_loss = {}
            for task, loss_fn in loss_fns.items():
                labels[task] = labels[task].to(device)
                st_loss[task] = loss_fn(input=outputs[task],
                                        target=labels[task])

            # Calculate multi-task loss
            # logger.debug("calculate mt loss")
            mt_loss = 0
            if len(st_loss) > 1:
                for task, loss in st_loss.items():
                    s = task_weights[task]  # s := log(sigma^2)
                    r = s * 0.5  # regularization term
                    if task in ["d_cls"]:
                        w = torch.exp(-s)  # weighting (class.)
                    elif task in ["d_reg"]:
                        w = 0.5 * torch.exp(-s)  # weighting (regr.)
                    else:
                        raise ValueError("Weighting not implemented!")
                    mt_loss += loss * w + r
            else:
                mt_loss = list(st_loss.values())[0]

            # Backward pass
            # logger.debug("backward pass")
            mt_loss.backward()
            # logger.debug("update weights")
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            # Output current training status
            # logger.debug("write log")
            if i == 0 or (i + 1) % cfg['training']['print_interval'] == 0:
                pad = str(len(str(cfg['training']['train_iters'])))
                print_str = ("Training Iteration: [{:>" + pad + "d}/{:d}]" +
                             "  Loss: {:>14.4f}" +
                             "  Time/Image: {:>7.4f}").format(
                                 i + 1, cfg['training']['train_iters'],
                                 mt_loss.item(), time_meter.avg /
                                 cfg['training']['batch_size'])
                logger.info(print_str)

                # Add training status to summaries
                writer.add_scalar('learning_rate',
                                  scheduler.get_lr()[0], i + 1)
                writer.add_scalar('batch_size', cfg['training']['batch_size'],
                                  i + 1)
                writer.add_scalar('loss/train_loss', mt_loss.item(), i + 1)
                for task, loss in st_loss.items():
                    writer.add_scalar("loss/single_task/" + task, loss, i + 1)
                for task, weight in task_weights.items():
                    writer.add_scalar("task_weights/" + task, weight, i + 1)
                time_meter.reset()

                # Add latest input image to summaries
                train_input = images[0].cpu().numpy()[::-1, :, :]
                writer.add_image("training/input", train_input, i + 1)

                # Add d_cls predictions and gt for latest sample to summaries
                if cfg["data"]["tasks"]["d_cls"]:
                    train_pred = outputs["d_cls"].detach().cpu().numpy().max(
                        0)[1].astype(np.uint8)
                    # train_pred = np.array(outputs["d_cls"][0].data.max(0)[1],
                    #                       dtype=np.uint8)
                    train_pred = t_loader.decode_segmap(train_pred)
                    train_pred = torch.tensor(np.rollaxis(train_pred, 2, 0))
                    writer.add_image("training/d_cls/prediction", train_pred,
                                     i + 1)

                    train_gt = t_loader.decode_segmap(
                        labels["d_cls"][0].data.cpu().numpy())
                    train_gt = torch.tensor(np.rollaxis(train_gt, 2, 0))
                    writer.add_image("training/d_cls/label", train_gt, i + 1)

                # Add d_reg predictions and gt for latest sample to summaries
                if cfg["data"]["tasks"]["d_reg"]:
                    train_pred = outputs["d_reg"][0]
                    train_pred = np.array(train_pred.data.cpu().numpy())
                    train_pred = t_loader.visualize_depths(
                        t_loader.restore_metric_depths(train_pred))
                    writer.add_image("training/d_reg/prediction", train_pred,
                                     i + 1)

                    train_gt = labels["d_reg"][0].data.cpu().numpy()
                    train_gt = t_loader.visualize_depths(
                        t_loader.restore_metric_depths(train_gt))
                    if len(train_gt.shape) < 3:
                        train_gt = np.expand_dims(train_gt, axis=0)
                    writer.add_image("training/d_reg/label", train_gt, i + 1)

            # Run mid-training validation
            if (i + 1) % cfg['training']['val_interval'] == 0:
                # or (i + 1) == cfg['training']['train_iters']:

                # Output current status
                # logger.debug("Training phase took " + str(timedelta(seconds=time.time() - timer_training_start)))
                timer_validation_start = time.time()
                logger.info("Validating model at training iteration" +
                            " {}...".format(i + 1))

                # Evaluate validation set
                model.eval()
                with torch.no_grad():
                    i_val = 0
                    pbar = tqdm(total=len(valloader), unit="batch")
                    for (images_val, labels_val) in valloader:

                        # Forward pass
                        images_val = images_val.to(device)
                        outputs_val = model(images_val)

                        # Clip predicted depth to min/max
                        if cfg["data"]["tasks"]["d_reg"]:
                            if cfg["data"]["depth_scaling"] is None:
                                logger.warning(
                                    "Using deprecated clip function!")
                                outputs_val["d_reg"] = torch.clamp(
                                    outputs_val["d_reg"], 0,
                                    cfg["data"]["max_depth"])
                            else:
                                outputs_val["d_reg"] = torch.clamp(
                                    outputs_val["d_reg"], 0, 1)

                        # Calculate single-task losses
                        st_loss_val = {}
                        for task, loss_fn in loss_fns.items():
                            labels_val[task] = labels_val[task].to(device)
                            st_loss_val[task] = loss_fn(
                                input=outputs_val[task],
                                target=labels_val[task])

                        # Calculate multi-task loss
                        mt_loss_val = 0
                        if len(st_loss) > 1:
                            for task, loss_val in st_loss_val.items():
                                s = task_weights[task]
                                r = s * 0.5
                                if task in ["d_cls"]:
                                    w = torch.exp(-s)
                                elif task in ["d_reg"]:
                                    w = 0.5 * torch.exp(-s)
                                else:
                                    raise ValueError(
                                        "Weighting not implemented!")
                                mt_loss_val += loss_val * w + r
                        else:
                            mt_loss_val = list(st_loss.values())[0]

                        # Accumulate metrics for summaries
                        val_loss_meter.update(mt_loss_val.item())

                        if cfg["data"]["tasks"]["d_cls"]:
                            running_metrics_val_d_cls.update(
                                labels_val["d_cls"].data.cpu().numpy(),
                                outputs_val["d_cls"].data.cpu().numpy().argmax(
                                    1))

                        if cfg["data"]["tasks"]["d_reg"]:
                            running_metrics_val_d_reg.update(
                                v_loader.restore_metric_depths(
                                    outputs_val["d_reg"].data.cpu().numpy()),
                                v_loader.restore_metric_depths(
                                    labels_val["d_reg"].data.cpu().numpy()))

                        # Update progressbar
                        i_val += 1
                        pbar.update()

                        # Stop validation early if max_iter key is set
                        if "max_iter" in cfg["validation"].keys() and \
                                i_val >= cfg["validation"]["max_iter"]:
                            logger.warning("Stopped validation early " +
                                           "because max_iter was reached")
                            break

                # Add sample input images from latest batch to summaries
                num_img_samples_val = min(len(images_val), NUM_IMG_SAMPLES)
                for cur_s in range(0, num_img_samples_val):
                    val_input = images_val[cur_s].cpu().numpy()[::-1, :, :]
                    writer.add_image(
                        "validation_sample_" + str(cur_s + 1) + "/input",
                        val_input, i + 1)

                    # Add predictions/ground-truth for d_cls to summaries
                    if cfg["data"]["tasks"]["d_cls"]:
                        val_pred = outputs_val["d_cls"][cur_s].data.max(0)[1]
                        val_pred = np.array(val_pred, dtype=np.uint8)
                        val_pred = t_loader.decode_segmap(val_pred)
                        val_pred = torch.tensor(np.rollaxis(val_pred, 2, 0))
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/prediction_d_cls", val_pred, i + 1)
                        val_gt = t_loader.decode_segmap(
                            labels_val["d_cls"][cur_s].data.cpu().numpy())
                        val_gt = torch.tensor(np.rollaxis(val_gt, 2, 0))
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/label_d_cls", val_gt, i + 1)

                    # Add predictions/ground-truth for d_reg to summaries
                    if cfg["data"]["tasks"]["d_reg"]:
                        val_pred = outputs_val["d_reg"][cur_s].cpu().numpy()
                        val_pred = v_loader.visualize_depths(
                            v_loader.restore_metric_depths(val_pred))
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/prediction_d_reg", val_pred, i + 1)

                        val_gt = labels_val["d_reg"][cur_s].data.cpu().numpy()
                        val_gt = v_loader.visualize_depths(
                            v_loader.restore_metric_depths(val_gt))
                        if len(val_gt.shape) < 3:
                            val_gt = np.expand_dims(val_gt, axis=0)
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/label_d_reg", val_gt, i + 1)

                # Add evaluation metrics for d_cls predictions to summaries
                if cfg["data"]["tasks"]["d_cls"]:
                    score, class_iou = running_metrics_val_d_cls.get_scores()
                    for k, v in score.items():
                        writer.add_scalar(
                            'validation/d_cls_metrics/{}'.format(k[:-3]), v,
                            i + 1)
                    for k, v in class_iou.items():
                        writer.add_scalar(
                            'validation/d_cls_metrics/class_{}'.format(k),
                            v, i + 1)
                    running_metrics_val_d_cls.reset()

                # Add evaluation metrics for d_reg predictions to summaries
                if cfg["data"]["tasks"]["d_reg"]:
                    writer.add_scalar('validation/d_reg_metrics/rel',
                                      running_metrics_val_d_reg.rel, i + 1)
                    running_metrics_val_d_reg.reset()

                # Add validation loss to summaries
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)

                # Output current status
                logger.info(
                    ("Validation Loss at Iteration {}: " + "{:>14.4f}").format(
                        i + 1, val_loss_meter.avg))
                val_loss_meter.reset()
                # logger.debug("Validation phase took {}".format(timedelta(seconds=time.time() - timer_validation_start)))
                timer_training_start = time.time()

                # Close progressbar
                pbar.close()

            # Save checkpoint
            if (i + 1) % cfg['training']['checkpoint_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters'] or \
               i == 0:
                state = {
                    "iteration": i + 1,
                    "model_state": model.state_dict(),
                    "task_weights": task_weights,
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict()
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint_iter_".format(cfg['model']['arch'],
                                                    cfg['data']['dataset']) +
                    str(i + 1) + ".pkl")
                torch.save(state, save_path)
                logger.info("Saved checkpoint at iteration {} to: {}".format(
                    i + 1, save_path))

            # Stop training if current iteration == max iterations
            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break

            i += 1
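
The block above bundles the iteration counter, model, optimizer, scheduler, and task weights into a single dict and writes it with torch.save at every checkpoint interval. A minimal, self-contained sketch of the same save/resume pattern follows; the helper names and the file-naming scheme are illustrative, not taken from the repository.

import os

import torch


def save_checkpoint(logdir, tag, iteration, model, optimizer, scheduler, task_weights=None):
    """Bundle everything needed to resume training into one .pkl file."""
    state = {
        "iteration": iteration,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "task_weights": task_weights,
    }
    save_path = os.path.join(logdir, "{}_checkpoint_iter_{}.pkl".format(tag, iteration))
    torch.save(state, save_path)
    return save_path


def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location="cpu"):
    """Restore a checkpoint; optimizer/scheduler are optional for weight-only loads."""
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state"])
    if optimizer is not None and "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    if scheduler is not None and "scheduler_state" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state"])
    return checkpoint.get("iteration", 0)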
Code Example #19
def train(cfg, writer, logger, start_iter=0, model_only=False, gpu=-1, save_dir=None):

    # Setup seeds and config
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))
    
    # Setup device
    if gpu == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cuda:%d" %gpu if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    if cfg["data"]["dataset"] == "softmax_cityscapes_convention":
        data_aug = get_composed_augmentations_softmax(augmentations)
    else:
        data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )
    v_loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    sampler = None
    if "sampling" in cfg["data"]:
        sampler = data.WeightedRandomSampler(
            weights = get_sampling_weights(t_loader, cfg["data"]["sampling"]),
            num_samples = len(t_loader),
            replacement = True
        )
    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        sampler=sampler,
        shuffle=sampler is None,
    )
    valloader = data.DataLoader(
        v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]
    )

    # Setup Metrics
    running_metrics_val = {"seg": runningScoreSeg(n_classes)}
    if "classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier( len(classes) )
    if "bin_classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["bin_classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier(2)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters:', total_params)

    if gpu == -1:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    else:
        model = torch.nn.DataParallel(model, device_ids=[gpu])
    
    model.apply(weights_init)
    pretrained_path='weights/hardnet_petite_base.pth'
    weights = torch.load(pretrained_path)
    model.module.base.load_state_dict(weights)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    print("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])
    loss_dict = get_loss_function(cfg, device)

    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"])
            )
            checkpoint = torch.load(cfg["training"]["resume"], map_location=device)
            model.load_state_dict(checkpoint["model_state"], strict=False)
            if not model_only:
                optimizer.load_state_dict(checkpoint["optimizer_state"])
                scheduler.load_state_dict(checkpoint["scheduler_state"])
                start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"]))

    if cfg["training"]["finetune"] is not None:
        if os.path.isfile(cfg["training"]["finetune"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["finetune"])
            )
            checkpoint = torch.load(cfg["training"]["finetune"])
            model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True
    loss_all = 0
    loss_n = 0

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, label_dict, _) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()

            images = images.to(device)
            optimizer.zero_grad()
            output_dict = model(images)

            loss = compute_loss(    # considers key names in loss_dict and output_dict
                loss_dict, images, label_dict, output_dict, device, t_loader
            )
            
            loss.backward()         # backprops sum of loss tensors, frozen components will have no grad_fn
            optimizer.step()
            c_lr = scheduler.get_lr()

            if i%1000 == 0:             # log images, seg ground truths, predictions
                pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy()
                gt_array = label_dict["seg"].data.cpu().numpy()
                softmax_gt_array = None
                if "softmax" in label_dict:
                    softmax_gt_array = label_dict["softmax"].data.max(1)[1].cpu().numpy()
                write_images_to_board(t_loader, images, gt_array, pred_array, i, name = 'train', softmax_gt = softmax_gt_array)

                if save_dir is not None:
                    image_array = images.data.cpu().numpy().transpose(0, 2, 3, 1)
                    write_images_to_dir(t_loader, image_array, gt_array, pred_array, i, save_dir, name = 'train', softmax_gt = softmax_gt_array)

            time_meter.update(time.time() - start_ts)
            loss_all += loss.item()
            loss_n += 1
            
            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}  lr={:.6f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_all / loss_n,
                    time_meter.avg / cfg["training"]["batch_size"],
                    c_lr[0],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"][
                "train_iters"
            ]:
                torch.cuda.empty_cache()
                model.eval() # set batchnorm and dropouts to work in eval mode
                loss_all = 0
                loss_n = 0
                with torch.no_grad():  # deactivate autograd to reduce memory usage
                    for i_val, (images_val, label_dict_val, _) in tqdm(enumerate(valloader)):
                        
                        images_val = images_val.to(device)
                        output_dict = model(images_val)
                        
                        val_loss = compute_loss(
                            loss_dict, images_val, label_dict_val, output_dict, device, v_loader
                        )
                        val_loss_meter.update(val_loss.item())

                        for name, metrics in running_metrics_val.items():
                            gt_array = label_dict_val[name].data.cpu().numpy()
                            if name+'_loss' in cfg['training'] and cfg['training'][name+'_loss']['name'] == 'l1':  # for binary classification
                                pred_array = output_dict[name].data.cpu().numpy()
                                pred_array = np.sign(pred_array)
                                pred_array[pred_array == -1] = 0
                                gt_array[gt_array == -1] = 0
                            else:
                                pred_array = output_dict[name].data.max(1)[1].cpu().numpy()

                            metrics.update(gt_array, pred_array)

                softmax_gt_array = None # log validation images
                pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy()
                gt_array = label_dict_val["seg"].data.cpu().numpy()
                if "softmax" in label_dict_val:
                    softmax_gt_array = label_dict_val["softmax"].data.max(1)[1].cpu().numpy()
                write_images_to_board(v_loader, images_val, gt_array, pred_array, i, 'validation', softmax_gt = softmax_gt_array)
                if save_dir is not None:
                    images_val = images_val.cpu().numpy().transpose(0, 2, 3, 1)
                    write_images_to_dir(v_loader, images_val, gt_array, pred_array, i, save_dir, name='validation', softmax_gt = softmax_gt_array)

                logger.info("Iter %d Val Loss: %.4f" % (i + 1, val_loss_meter.avg))
                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)

                for name, metrics in running_metrics_val.items():
                    
                    overall, classwise = metrics.get_scores()
                    
                    for k, v in overall.items():
                        logger.info("{}_{}: {}".format(name, k, v))
                        writer.add_scalar("val_metrics/{}_{}".format(name, k), v, i + 1)

                        if k == cfg["training"]["save_metric"]:
                            curr_performance = v

                    for metric_name, metric in classwise.items():
                        for k, v in metric.items():
                            logger.info("{}_{}_{}: {}".format(name, metric_name, k, v))
                            writer.add_scalar("val_metrics/{}_{}_{}".format(name, metric_name, k), v, i + 1)

                    metrics.reset()
                
                state = {
                      "epoch": i + 1,
                      "model_state": model.state_dict(),
                      "optimizer_state": optimizer.state_dict(),
                      "scheduler_state": scheduler.state_dict(),
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                )
                torch.save(state, save_path)

                if curr_performance >= best_iou:
                    best_iou = curr_performance
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                torch.cuda.empty_cache()

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
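
compute_loss() in the example above is project-specific; according to the inline comment it matches key names between loss_dict and the model's output_dict and backpropagates the sum of the per-task losses. A minimal sketch of that aggregation, assuming loss_dict maps task names to callable loss functions (the signature and names here are illustrative only):

import torch


def compute_loss_sketch(loss_dict, label_dict, output_dict, device):
    """Sum one loss term per task whose name appears in both loss_dict and output_dict."""
    total = torch.zeros((), device=device)
    for name, loss_fn in loss_dict.items():
        if name not in output_dict or name not in label_dict:
            continue  # skip tasks with no prediction or no label in this batch
        prediction = output_dict[name]
        target = label_dict[name].to(device)
        total = total + loss_fn(prediction, target)
    return total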
Code Example #20
def eval(cfg, writer, logger, logdir):

    # Setup seeds
    #torch.manual_seed(cfg.get("seed", 1337))
    #torch.cuda.manual_seed(cfg.get("seed", 1337))
    #np.random.seed(cfg.get("seed", 1337))
    #random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataloader_type"])

    data_root = cfg["data"]["data_root"]
    presentation_root = cfg["data"]["presentation_root"]

    v_loader = data_loader(data_root=data_root,
                           presentation_root=presentation_root,
                           is_transform=True,
                           img_size=(cfg["data"]["img_rows"],
                                     cfg["data"]["img_cols"]),
                           augmentations=data_aug,
                           test_mode=True)

    n_classes = v_loader.n_classes

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"],
                                shuffle=False)

    # Setup Metrics
    # running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes, defaultParams).to(device)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # train_loss_meter = averageMeter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = 0
    pres_results = [
    ]  # a final list of all <image, label, output> of all presentations
    img_list = []

    while i < cfg["training"]["num_presentations"]:

        #                 #
        #  TESTING PHASE  #
        #                 #
        i += 1

        training_state_dict = model.state_dict()
        hebb = model.initialZeroHebb().to(device)
        valloader.dataset.random_select()
        start_ts = time.time()

        for idx, (images_val, labels_val) in enumerate(
                valloader, 1):  # get a single test presentation

            img = torchvision.utils.make_grid(images_val).numpy()
            img = np.transpose(img, (1, 2, 0))
            img = img[:, :, ::-1]
            img_list.append(img)
            pres_results.append(decode_segmap(labels_val.numpy()))
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            if idx <= 5:
                model.eval()
                with torch.no_grad():
                    outputs, hebb = model(images_val,
                                          labels_val,
                                          hebb,
                                          device,
                                          test_mode=False)
            else:
                model.train()
                optimizer.zero_grad()
                outputs, hebb = model(images_val,
                                      labels_val,
                                      hebb,
                                      device,
                                      test_mode=True)
                loss = loss_fn(input=outputs, target=labels_val)
                loss.backward()
                optimizer.step()

                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()

                running_metrics_val.update(gt, pred)
                val_loss_meter.update(loss.item())

                # Turning the image, label, and output into plottable formats
                '''img = torchvision.utils.make_grid(images_val.cpu()).numpy()
                img = np.transpose(img, (1, 2, 0))
                img = img[:, :, ::-1]
                print("img.shape",img.shape)
                print("gt.shape and type",gt.shape, gt.dtype)
                print("pred.shape and type",pred.shape, pred.dtype)'''

                cla, cnt = np.unique(pred, return_counts=True)
                print("Unique classes predicted = {}, counts = {}".format(
                    cla, cnt))
                #pres_results.append(img)
                #pres_results.append(decode_segmap(gt))
                pres_results.append(decode_segmap(pred))

        time_meter.update(time.time() -
                          start_ts)  # time taken per presentation
        model.load_state_dict(
            training_state_dict)  # revert to the pre-presentation training parameters

        # Display presentations stats
        fmt_str = "Pres [{:d}/{:d}]  Loss: {:.4f}  Time/Pres: {:.4f}"
        print_str = fmt_str.format(
            i + 1,
            cfg["training"]["num_presentations"],
            loss.item(),
            time_meter.avg / cfg["training"]["batch_size"],
        )
        print(print_str)
        logger.info(print_str)
        writer.add_scalar("loss/test_loss", loss.item(), i + 1)
        time_meter.reset()

        writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
        logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

        # Display presentation metrics
        score, class_iou = running_metrics_val.get_scores()
        for k, v in score.items():
            print(k, v)
            logger.info("{}: {}".format(k, v))
            writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

        #for k, v in class_iou.items():
        #    logger.info("{}: {}".format(k, v))
        #    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

        val_loss_meter.reset()
        running_metrics_val.reset()

    # save presentations to a png image file
    save_presentations(pres_results=pres_results,
                       num_pres=cfg["training"]["num_presentations"],
                       num_col=7,
                       logdir=logdir,
                       name="pre_results.png")
    save_presentations(pres_results=img_list,
                       num_pres=cfg["training"]["num_presentations"],
                       num_col=6,
                       logdir=logdir,
                       name="img_list.png")
Code Example #21
File: train_dep.py  Project: lwawrla/pytorch-semseg
def train(cfg, writer, logger):
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = cityscapesLoader
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()

    # get loss_seg meter and also loss_dep meter
    # loss_seg_meter = averageMeter()
    # loss_dep_meter = averageMeter()
    time_meter = averageMeter()
    acc_result_total = averageMeter()
    acc_result_correct = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, masks, depths) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            depths = depths.to(device)

            # print(images.shape)
            optimizer.zero_grad()
            outputs = model(images).squeeze(1)

            # -----------------------------------------------------------------
            # add depth loss

            # -----------------------------------------------------------------
            # MSE loss
            # loss_dep = F.mse_loss(input=outputs[:, -1,:,:], target=depths, reduction='mean')

            # -----------------------------------------------------------------
            # Berhu loss; loss_dep = loss
            loss = berhu_loss_function(prediction=outputs, target=depths)
            masks = masks.type(torch.cuda.ByteTensor)
            loss = torch.sum(loss[masks]) / torch.sum(masks)

            # -----------------------------------------------------------------

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}] loss_dep: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg["training"]["train_iters"], loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, masks_val,
                                depths_val) in enumerate(valloader):
                        images_val = images_val.to(device)

                        # add depth to device
                        depths_val = depths_val.to(device)

                        outputs = model(images_val).squeeze(1)
                        # depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3))

                        # -----------------------------------------------------------------
                        # berhu loss function
                        val_loss = berhu_loss_function(prediction=outputs,
                                                       target=depths_val)
                        masks_val = masks_val.type(torch.cuda.ByteTensor)
                        # keep val_loss as a float tensor; casting it to ByteTensor would truncate the loss values
                        print('val_loss1 is', val_loss)
                        val_loss = torch.sum(
                            val_loss[masks_val]) / torch.sum(masks_val)
                        print('val_loss2 is', val_loss)

                        # -----------------------------------------------------------------
                        # Update

                        val_loss_meter.update(val_loss.item())

                        outputs = outputs.cpu().numpy()
                        depths_val = depths_val.cpu().numpy()
                        masks_val = masks_val.cpu().numpy()

                        # depths_val = depths_val.type(torch.cuda.FloatTensor)
                        # outputs = outputs.type(torch.cuda.FloatTensor)

                        # -----------------------------------------------------------------
                        # Try the following against error:
                        # RuntimeWarning: invalid value encountered in double_scalars: acc = np.diag(hist).sum() / hist.sum()
                        # Similar error: https://github.com/meetshah1995/pytorch-semseg/issues/118

                        acc_1 = outputs / depths_val
                        acc_2 = 1 / acc_1
                        acc_threshold = np.maximum(acc_1, acc_2)

                        acc_result_total.update(np.sum(masks_val))
                        acc_result_correct.update(
                            np.sum(
                                np.logical_and(acc_threshold < 1.25,
                                               masks_val)))

                print("Iter {:d}, val_loss {:.4f}".format(
                    i + 1, val_loss_meter.avg))
                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                acc_result = float(acc_result_correct.sum) / float(
                    acc_result_total.sum)
                print("Iter {:d}, acc_1.25 {:.4f}".format(i + 1, acc_result))
                logger.info("Iter %d acc_1.25: %.4f" % (i + 1, acc_result))

                # -----------------------------------------------------------------

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                acc_result_total.reset()
                acc_result_correct.reset()

                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                    # insert print function to see if the losses are correct

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
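
berhu_loss_function above is not shown in this excerpt. The reverse Huber (BerHu) loss commonly used for depth regression behaves like L1 for residuals below a threshold c and like a scaled L2 above it, with c often set to a fraction of the largest residual in the batch. A hedged, elementwise sketch follows (unreduced, so it can be masked and summed as in the loop above); the repository's implementation may differ in the threshold choice and reduction.

import torch


def berhu_loss_sketch(prediction, target, threshold_ratio=0.2):
    """Elementwise reverse Huber (BerHu) loss: L1 below c, scaled L2 above c."""
    diff = torch.abs(prediction - target)
    c = threshold_ratio * diff.max().detach()
    l2_part = (diff ** 2 + c ** 2) / (2 * c + 1e-12)  # epsilon guards the all-zero-residual case
    return torch.where(diff <= c, diff, l2_part)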
Code Example #22
File: train_SE.py  Project: czha5168/pytorch-semseg
def train(cfg, writer, logger):
    # Setup dataset split before setting up the seed for random
    if cfg['data']['dataset'] == 'miccai2008':
        split_info = init_data_split_miccai2008(
            cfg['data']['path'])  # miccai2008 dataset
    elif cfg['data']['dataset'] == 'sasha':
        split_info = init_data_split_sasha(
            cfg['data']['path'])  # miccai2008 dataset

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Cross Entropy Weight
    weight = prep_class_val_weights(cfg['training']['cross_entropy_ratio'])

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    print('augmentations_cfg:', augmentations)
    data_aug = get_composed_augmentations3d(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug,
                           split_info=split_info,
                           patch_size=cfg['training']['patch_size'],
                           mods=cfg['data']['mods'])
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['val_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           split_info=split_info,
                           patch_size=cfg['training']['patch_size'],
                           mods=cfg['data']['mods'])

    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=False)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    model.apply(weights_init)
    params = sum([
        np.prod(p.size())
        for p in filter(lambda p: p.requires_grad, model.parameters())
    ]) / 1e6
    print('NumOfParams:{}M'.format(params))
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    softmax_function = nn.Softmax(dim=1)

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i_train_iter = start_iter

    display('Training from {}th iteration\n'.format(i_train_iter))
    while i_train_iter < cfg['training']['train_iters']:
        i_batch_idx = 0
        train_iter_start_time = time.time()
        for (images, labels, case_index_list) in trainloader:
            start_ts_network = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs_FM = model(images)

            #print('Unique on labels:{}'.format(np.unique(labels.data.cpu().numpy())))    #[0, 1]
            #print('Unique on outputs:{}'.format(np.unique(outputs_FM.data.cpu().numpy())))  #[-1.15, +0.39]
            log('TrainIter=> images.size():{} labels.size():{} | outputs.size():{}'
                .format(images.size(), labels.size(), outputs_FM.size()))
            loss = cfg['training']['loss_balance_ratio'] * loss_fn(
                input=outputs_FM,
                target=labels,
                weight=weight,
                size_average=cfg['training']['loss']['size_average']
            )  # Input: FM; softmax is applied inside the cross-entropy loss function

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts_network)

            print_per_batch_check = True if cfg['training'][
                'print_interval_per_batch'] else i_batch_idx + 1 == len(
                    trainloader)
            if (i_train_iter + 1) % cfg['training'][
                    'print_interval'] == 0 and print_per_batch_check:
                fmt_str = "Iter [{:d}/{:d}::{:d}/{:d}]  [Loss: {:.4f}]  NetworkTime/Image: {:.4f}"
                print_str = fmt_str.format(
                    i_train_iter + 1, cfg['training']['train_iters'],
                    i_batch_idx + 1, len(trainloader), loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                display(print_str)
                writer.add_scalar('loss/train_loss', loss.item(),
                                  i_train_iter + 1)
                time_meter.reset()
            i_batch_idx += 1
        entire_time_all_cases = time.time() - train_iter_start_time
        display(
            'EntireTime for {}th training iteration: {:.4f}   EntireTime/Image: {:.4f}'
            .format(
                i_train_iter + 1, entire_time_all_cases,
                entire_time_all_cases /
                (len(trainloader) * cfg['training']['batch_size'])))

        validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \
                           (i_train_iter + 1) == cfg['training']['train_iters']
        if not validation_check:
            print('')
        else:
            model.eval()
            with torch.no_grad():
                for i_val, (images_val, labels_val,
                            case_index_list_val) in enumerate(valloader):
                    images_val = images_val.to(device)
                    labels_val = labels_val.to(device)

                    outputs_FM_val = model(images_val)
                    log(
                        'ValIter=> images_val.size():{} labels_val.size():{} | outputs.size():{}'
                        .format(images_val.size(), labels_val.size(),
                                outputs_FM_val.size())
                    )  # Input: FM; softmax is applied inside the cross-entropy loss function

                    val_loss = cfg['training']['loss_balance_ratio'] * loss_fn(
                        input=outputs_FM_val,
                        target=labels_val,
                        weight=weight,
                        size_average=cfg['training']['loss']['size_average'])

                    outputs_CLASS_val = outputs_FM_val.data.max(1)[1]
                    outputs_PROB_val = softmax_function(outputs_FM_val.data)
                    outputs_lesionPROB_val = outputs_PROB_val[:, 1, :, :, :]

                    running_metrics_val.update(labels_val.data.cpu().numpy(),
                                               outputs_CLASS_val.cpu().numpy())
                    val_loss_meter.update(val_loss.item())
                    '''
                        This for-loop visualizes the validation data via TensorBoard.
                        It takes roughly 3 s.
                    '''
                    for batch_identifier_index, case_index in enumerate(
                            case_index_list_val):
                        tensor_grid = []
                        image_val = images_val[
                            batch_identifier_index, :, :, :, :].float(
                            )  #torch.Size([3, 160, 160, 160])
                        label_val = labels_val[
                            batch_identifier_index, :, :, :].float(
                            )  #torch.Size([160, 160, 160])
                        output_lesionFM_val = outputs_FM_val[
                            batch_identifier_index,
                            1, :, :, :].float()  #torch.Size([160, 160, 160])
                        output_nonlesFM_val = outputs_FM_val[
                            batch_identifier_index,
                            0, :, :, :].float()  #torch.Size([160, 160, 160])
                        output_CLASS_val = outputs_CLASS_val[
                            batch_identifier_index, :, :, :].float(
                            )  #torch.Size([160, 160, 160])
                        output_lesionPROB_val = outputs_lesionPROB_val[
                            batch_identifier_index, :, :, :].float(
                            )  #torch.Size([160, 160, 160])
                        for z_index in range(images_val.size()[-1]):
                            label_slice = label_val[:, :, z_index]
                            output_CLASS_slice = output_CLASS_val[:, :,
                                                                  z_index]
                            if label_slice.sum(
                            ) == 0 and output_CLASS_slice.sum() == 0:
                                continue

                            image_slice = image_val[:, :, :, z_index]
                            output_nonlesFM_slice = output_nonlesFM_val[:, :,
                                                                        z_index]
                            output_lesionFM_slice = output_lesionFM_val[:, :,
                                                                        z_index]
                            output_lesionPROB_slice = output_lesionPROB_val[:, :,
                                                                            z_index]

                            label_slice = F.pad(label_slice.unsqueeze_(0),
                                                (0, 0, 0, 0, 1, 1))
                            output_CLASS_slice = F.pad(
                                output_CLASS_slice.unsqueeze_(0),
                                (0, 0, 0, 0, 2, 0))
                            output_nonlesFM_slice = output_nonlesFM_slice.unsqueeze_(
                                0).repeat(3, 1, 1)
                            output_lesionFM_slice = output_lesionFM_slice.unsqueeze_(
                                0).repeat(3, 1, 1)
                            output_lesionPROB_slice = output_lesionPROB_slice.unsqueeze_(
                                0).repeat(3, 1, 1)

                            slice_list = [
                                image_slice, output_nonlesFM_slice,
                                output_lesionFM_slice, output_lesionPROB_slice,
                                output_CLASS_slice, label_slice
                            ]
                            #slice_list = [image_slice, output_lesionFM_slice, output_lesionPROB_slice, output_CLASS_slice, label_slice]
                            slice_grid = make_grid(slice_list, padding=20)
                            tensor_grid.append(slice_grid)
                        if len(tensor_grid) == 0:
                            continue
                        tensorboard_image_tensor = make_grid(
                            tensor_grid,
                            nrow=int(math.sqrt(len(tensor_grid) / 6)) + 1,
                            padding=0).permute(1, 2, 0).cpu().numpy()
                        writer.add_image(case_index, tensorboard_image_tensor,
                                         i_train_iter + 1)
            writer.add_scalar('loss/val_loss', val_loss_meter.avg,
                              i_train_iter + 1)
            logger.info("Iter %d Loss_total: %.4f" %
                        (i_train_iter + 1, val_loss_meter.avg))
            '''
                This code block calculates and logs the evaluation metrics
            '''
            score, class_iou = running_metrics_val.get_scores()
            print(
                '\x1b[1;32;44mValidationDataLoaded-EXPINDEX={}'.format(run_id))
            for k, v in score.items():
                print(k, v)
                logger.info('{}: {}'.format(k, v))
                if isinstance(v, list): continue
                writer.add_scalar('val_metrics/{}'.format(k), v,
                                  i_train_iter + 1)

            for k, v in class_iou.items():
                print('IOU:cls_{}:{}'.format(k, v))
                logger.info('{}: {}'.format(k, v))
                writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                  i_train_iter + 1)
            print('\x1b[0m\n')
            val_loss_meter.reset()
            running_metrics_val.reset()
            '''
                This if-check updates the best-model checkpoint
            '''
            if score["Mean IoU       : \t"] >= best_iou:
                #if score["Patch DICE AVER: \t"] >= best_iou:
                #best_iou = score["Patch DICE AVER: \t"]
                best_iou = score["Mean IoU       : \t"]

                state = {
                    "epoch": i_train_iter + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                  cfg['data']['dataset']))
                torch.save(state, save_path)
        i_train_iter += 1
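
prep_class_val_weights above presumably turns cfg['training']['cross_entropy_ratio'] into a per-class weight passed to loss_fn; the helper itself is not shown. A minimal sketch of class-weighted cross entropy for a two-class (lesion vs. background) volume, with the weighting scheme assumed rather than taken from the repository:

import torch
import torch.nn.functional as F


def weighted_cross_entropy_sketch(logits, target, foreground_weight):
    """Cross entropy where the positive (lesion) class is up-weighted.

    logits: (N, 2, D, H, W) raw scores; target: (N, D, H, W) LongTensor in {0, 1}.
    """
    weight = torch.tensor([1.0, float(foreground_weight)], device=logits.device)
    return F.cross_entropy(logits, target, weight=weight)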
Code Example #23
File: train.py  Project: HMellor/4YP_code
def train(cfg, writer, logger_old, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    if isinstance(cfg['training']['loss']['superpixels'], int):
        use_superpixels = True
        cfg['data']['train_split'] = 'train_super'
        cfg['data']['val_split'] = 'val_super'
        setup_superpixels(cfg['training']['loss']['superpixels'])
    elif cfg['training']['loss']['superpixels'] is not None:
        raise Exception(
            "cfg['training']['loss']['superpixels'] is of the wrong type")
    else:
        use_superpixels = False

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           superpixels=cfg['training']['loss']['superpixels'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        superpixels=cfg['training']['loss']['superpixels'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    running_metrics_train = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger_old.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger_old.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger_old.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger_old.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger_old.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    train_loss_meter = averageMeter()
    time_meter = averageMeter()

    train_len = t_loader.train_len
    val_static = 0
    best_iou = -100.0
    i = start_iter
    j = 0
    flag = True

    # Prepare logging
    xp_name = cfg['model']['arch'] + '_' + \
        cfg['training']['loss']['name'] + '_' + args.name
    xp = logger.Experiment(xp_name,
                           use_visdom=True,
                           visdom_opts={
                               'server': 'http://localhost',
                               'port': 8098
                           },
                           time_indexing=False,
                           xlabel='Epoch')
    # log the hyperparameters of the experiment
    xp.log_config(flatten(cfg))
    # create parent metric for training metrics (easier interface)
    xp.ParentWrapper(tag='train',
                     name='parent',
                     children=(xp.AvgMetric(name="loss"),
                               xp.AvgMetric(name='acc'),
                               xp.AvgMetric(name='acccls'),
                               xp.AvgMetric(name='fwavacc'),
                               xp.AvgMetric(name='meaniu')))
    xp.ParentWrapper(tag='val',
                     name='parent',
                     children=(xp.AvgMetric(name="loss"),
                               xp.AvgMetric(name='acc'),
                               xp.AvgMetric(name='acccls'),
                               xp.AvgMetric(name='fwavacc'),
                               xp.AvgMetric(name='meaniu')))
    best_loss = xp.BestMetric(tag='val-best', name='loss', mode='min')
    best_acc = xp.BestMetric(tag='val-best', name='acc')
    best_acccls = xp.BestMetric(tag='val-best', name='acccls')
    best_fwavacc = xp.BestMetric(tag='val-best', name='fwavacc')
    best_meaniu = xp.BestMetric(tag='val-best', name='meaniu')

    xp.plotter.set_win_opts(name="loss", opts={'title': 'Loss'})
    xp.plotter.set_win_opts(name="acc", opts={'title': 'Micro-Average'})
    xp.plotter.set_win_opts(name="acccls", opts={'title': 'Macro-Average'})
    xp.plotter.set_win_opts(name="fwavacc", opts={'title': 'FreqW Accuracy'})
    xp.plotter.set_win_opts(name="meaniu", opts={'title': 'Mean IoU'})

    it_per_step = cfg['training']['acc_batch_size']
    eff_batch_size = cfg['training']['batch_size'] * it_per_step
    while i <= train_len * (cfg['training']['epochs']) and flag:
        for (images, labels, labels_s, masks) in trainloader:
            i += 1
            j += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            labels_s = labels_s.to(device)
            masks = masks.to(device)

            outputs = model(images)
            if use_superpixels:
                outputs_s, labels_s, sizes = convert_to_superpixels(
                    outputs, labels_s, masks)
                loss = loss_fn(input=outputs_s, target=labels_s, size=sizes)
                outputs = convert_to_pixels(outputs_s, outputs, masks)
            else:
                loss = loss_fn(input=outputs, target=labels)

            # accumulate train metrics during train
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()
            running_metrics_train.update(gt, pred)
            train_loss_meter.update(loss.item())

            if args.evaluate:
                decoded = t_loader.decode_segmap(np.squeeze(pred, axis=0))
                misc.imsave("./{}.png".format(i), decoded)
                image_save = np.transpose(
                    np.squeeze(images.data.cpu().numpy(), axis=0), (1, 2, 0))
                misc.imsave("./{}.jpg".format(i), image_save)

            # accumulate gradients based on the accumulation batch size
            if i % it_per_step == 1 or it_per_step == 1:
                optimizer.zero_grad()

            grad_rescaling = torch.tensor(1. / it_per_step).type_as(loss)
            loss.backward(grad_rescaling)
            if (i + 1) % it_per_step == 1 or it_per_step == 1:
                optimizer.step()
                optimizer.zero_grad()

            time_meter.update(time.time() - start_ts)
            # training logs
            if (j + 1) % (cfg['training']['print_interval'] *
                          it_per_step) == 0:
                fmt_str = "Epoch [{}/{}] Iter [{}/{:d}] Loss: {:.4f}  Time/Image: {:.4f}"
                total_iter = int(train_len / eff_batch_size)
                total_epoch = int(cfg['training']['epochs'])
                current_epoch = ceil((i + 1) / train_len)
                current_iter = int((j + 1) / it_per_step)
                print_str = fmt_str.format(
                    current_epoch, total_epoch, current_iter, total_iter,
                    loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger_old.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()
            # end of epoch evaluation
            if (i + 1) % train_len == 0 or \
               (i + 1) == train_len * (cfg['training']['epochs']):
                optimizer.step()
                optimizer.zero_grad()
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val, labels_val_s,
                                masks_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        labels_val_s = labels_val_s.to(device)
                        masks_val = masks_val.to(device)

                        outputs = model(images_val)
                        if use_superpixels:
                            outputs_s, labels_val_s, sizes_val = convert_to_superpixels(
                                outputs, labels_val_s, masks_val)
                            val_loss = loss_fn(input=outputs_s,
                                               target=labels_val_s,
                                               size=sizes_val)
                            outputs = convert_to_pixels(
                                outputs_s, outputs, masks_val)
                        else:
                            val_loss = loss_fn(input=outputs,
                                               target=labels_val)
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg,
                                  i + 1)
                logger_old.info("Epoch %d Val Loss: %.4f" % (int(
                    (i + 1) / train_len), val_loss_meter.avg))
                logger_old.info("Epoch %d Train Loss: %.4f" % (int(
                    (i + 1) / train_len), train_loss_meter.avg))

                score, class_iou = running_metrics_train.get_scores()
                print("Training metrics:")
                for k, v in score.items():
                    print(k, v)
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('train_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('train_metrics/cls_{}'.format(k), v,
                                      i + 1)

                xp.Parent_Train.update(loss=train_loss_meter.avg,
                                       acc=score['Overall Acc: \t'],
                                       acccls=score['Mean Acc : \t'],
                                       fwavacc=score['FreqW Acc : \t'],
                                       meaniu=score['Mean IoU : \t'])

                score, class_iou = running_metrics_val.get_scores()
                print("Validation metrics:")
                for k, v in score.items():
                    print(k, v)
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                xp.Parent_Val.update(loss=val_loss_meter.avg,
                                     acc=score['Overall Acc: \t'],
                                     acccls=score['Mean Acc : \t'],
                                     fwavacc=score['FreqW Acc : \t'],
                                     meaniu=score['Mean IoU : \t'])

                xp.Parent_Val.log_and_reset()
                xp.Parent_Train.log_and_reset()
                best_loss.update(xp.loss_val).log()
                best_acc.update(xp.acc_val).log()
                best_acccls.update(xp.acccls_val).log()
                best_fwavacc.update(xp.fwavacc_val).log()
                best_meaniu.update(xp.meaniu_val).log()

                visdir = os.path.join('runs', cfg['training']['loss']['name'],
                                      args.name, 'plots.json')
                xp.to_json(visdir)

                val_loss_meter.reset()
                train_loss_meter.reset()
                running_metrics_val.reset()
                running_metrics_train.reset()
                j = 0

                if score["Mean IoU : \t"] >= best_iou:
                    val_static = 0
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)
                else:
                    val_static += 1

            if (i + 1) == train_len * (
                    cfg['training']['epochs']) or val_static == 10:
                flag = False
                break
    return best_iou
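
The loop above calls optimizer.step() only every it_per_step mini-batches (and once more at the end of each epoch), which amounts to gradient accumulation with an effective batch size of batch_size * it_per_step. A minimal, self-contained sketch of that pattern; the model, data, and loss scaling here are placeholders and assumptions, not the original code (whether the original scales the loss by it_per_step is not visible in this excerpt):

import torch
import torch.nn as nn

model = nn.Linear(8, 3)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 4                                           # plays the role of it_per_step

optimizer.zero_grad()
for step in range(16):
    images = torch.randn(2, 8)                            # dummy mini-batch
    labels = torch.randint(0, 3, (2,))
    loss = loss_fn(model(images), labels) / accum_steps   # scale so the summed gradient matches one large batch
    loss.backward()
    if (step + 1) % accum_steps == 0:                     # update once per accumulation window
        optimizer.step()
        optimizer.zero_grad()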
Code Example #24
def train(cfg, writer, logger, run_id):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    torch.backends.cudnn.benchmark = True

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    # model = get_model(cfg['model'], n_classes).to(device)
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))

    # model=apex.parallel.convert_syncbn_model(model)
    model = model.to(device)

    # a=range(torch.cuda.device_count())
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # model = torch.nn.DataParallel(model, device_ids=[0,1])
    # model = encoding.parallel.DataParallelModel(model, device_ids=[0, 1])

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    loss_fn = get_loss_function(cfg)
    # loss_fn== encoding.parallel.DataParallelCriterion(loss_fn, device_ids=[0, 1])
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    train_data_len = t_loader.__len__()
    batch_size = cfg['training']['batch_size']
    epoch = cfg['training']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)

    val_rlt_f1 = []
    val_rlt_IoU = []
    best_f1_till_now = 0
    best_IoU_till_now = 0

    while i <= train_iter and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            # optimizer.backward(loss)

            optimizer.step()

            time_meter.update(time.time() - start_ts)

            ### add by Sprit
            time_meter_val.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, train_iter, loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == train_iter:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        # val_loss_meter.update(val_loss.item())

                # writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                # val_loss_meter.reset()
                running_metrics_val.reset()

                ### add by Sprit
                avg_f1 = score["Mean F1 : \t"]
                avg_IoU = score["Mean IoU : \t"]
                val_rlt_f1.append(avg_f1)
                val_rlt_IoU.append(score["Mean IoU : \t"])

                if avg_f1 >= best_f1_till_now:
                    best_f1_till_now = avg_f1
                    correspond_iou = score["Mean IoU : \t"]
                    best_epoch_till_now = i + 1
                print("\nBest F1 till now = ", best_f1_till_now)
                print("Correspond IoU= ", correspond_iou)
                print("Best F1 Iter till now= ", best_epoch_till_now)

                if avg_IoU >= best_IoU_till_now:
                    best_IoU_till_now = avg_IoU
                    correspond_f1 = score["Mean F1 : \t"]
                    correspond_acc = score["Overall Acc: \t"]
                    best_epoch_till_now = i + 1
                print("Best IoU till now = ", best_IoU_till_now)
                print("Correspond F1= ", correspond_f1)
                print("Correspond OA= ", correspond_acc)
                print("Best IoU Iter till now= ", best_epoch_till_now)

                ### add by Sprit
                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - i)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if remain_time > 0:
                    train_time = "Remain training time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain training time : Training completed.\n"
                print(train_time)

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == train_iter:
                flag = False
                break
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_f1,
                  cfg['training']['val_interval'])
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_IoU,
                  cfg['training']['val_interval'])
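
The remaining-time estimate printed during validation above is just the average seconds per iteration multiplied by the iterations left, formatted with divmod. A small self-contained sketch of that calculation (the numbers in the usage line are illustrative):

def format_eta(avg_iter_seconds, iters_left):
    # Estimate and format the remaining training time, as in the loop above.
    remain = avg_iter_seconds * iters_left
    minutes, seconds = divmod(remain, 60)
    hours, minutes = divmod(minutes, 60)
    return "Remain training time = %d hours %d minutes %d seconds" % (hours, minutes, seconds)

print(format_eta(0.85, 1200))   # -> "Remain training time = 0 hours 17 minutes 0 seconds"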
Code Example #25
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
        n_classes=20,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])
    # -----------------------------------------------------------------
    # Setup Metrics (subtract one class)
    running_metrics_val = runningScore(n_classes - 1)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()

    # get loss_seg meter and also loss_dep meter

    loss_seg_meter = averageMeter()
    loss_dep_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, masks, depths) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            depths = depths.to(device)

            #print(images.shape)
            optimizer.zero_grad()
            outputs = model(images)
            #print('depths size: ', depths.size())
            #print('output shape: ', outputs.shape)

            loss_seg = loss_fn(input=outputs[:, :-1, :, :], target=labels)

            # -----------------------------------------------------------------
            # add depth loss

            # -----------------------------------------------------------------
            # MSE loss
            # loss_dep = F.mse_loss(input=outputs[:, -1,:,:], target=depths, reduction='mean')

            # -----------------------------------------------------------------
            # Berhu loss
            loss_dep = berhu_loss_function(prediction=outputs[:, -1, :, :],
                                           target=depths)
            #loss_dep = loss_dep.type(torch.cuda.ByteTensor)
            masks = masks.type(torch.cuda.ByteTensor)
            loss_dep = torch.sum(loss_dep[masks]) / torch.sum(masks)
            print('loss depth', loss_dep)
            loss = loss_dep + loss_seg
            # -----------------------------------------------------------------

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  loss_seg: {:.4f}  loss_dep: {:.4f}  overall loss: {:.4f}   Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg["training"]["train_iters"], loss_seg.item(),
                    loss_dep.item(), loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val, masks_val,
                                depths_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        print('images_val shape', images_val.size())
                        # add depth to device
                        depths_val = depths_val.to(device)

                        outputs = model(images_val)
                        #depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3))

                        # -----------------------------------------------------------------
                        # loss function for segmentation
                        print('output shape', outputs.size())
                        val_loss_seg = loss_fn(input=outputs[:, :-1, :, :],
                                               target=labels_val)

                        # -----------------------------------------------------------------
                        # MSE loss
                        # val_loss_dep = F.mse_loss(input=outputs[:, -1, :, :], target=depths_val, reduction='mean')

                        # -----------------------------------------------------------------
                        # berhu loss function
                        val_loss_dep = berhu_loss_function(
                            prediction=outputs[:, -1, :, :], target=depths_val)
                        # val_loss_dep = val_loss_dep.type(torch.cuda.ByteTensor)  # casting the loss to a ByteTensor would truncate it; only the mask needs the cast (as in the training loop above)
                        masks_val = masks_val.type(torch.cuda.ByteTensor)
                        val_loss_dep = torch.sum(
                            val_loss_dep[masks_val]) / torch.sum(masks_val)
                        val_loss = val_loss_dep + val_loss_seg
                        # -----------------------------------------------------------------

                        prediction = outputs[:, :-1, :, :]
                        prediction = prediction.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        # adapt metrics to seg and dep
                        running_metrics_val.update(gt, prediction)
                        loss_seg_meter.update(val_loss_seg.item())
                        loss_dep_meter.update(val_loss_dep.item())

                        # -----------------------------------------------------------------
                        # get rid of val_loss_meter
                        # val_loss_meter.update(val_loss.item())
                        # writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                        # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                        # -----------------------------------------------------------------

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                print("Segmentation loss is {}".format(loss_seg_meter.avg))
                logger.info("Segmentation loss is {}".format(
                    loss_seg_meter.avg))
                #writer.add_scalar("Segmentation loss is {}".format(loss_seg_meter.avg), i + 1)

                print("Depth loss is {}".format(loss_dep_meter.avg))
                logger.info("Depth loss is {}".format(loss_dep_meter.avg))
                #writer.add_scalar("Depth loss is {}".format(loss_dep_meter.avg), i + 1)

                val_loss_meter.reset()
                loss_seg_meter.reset()
                loss_dep_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                    # insert print function to see if the losses are correct

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
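
berhu_loss_function is used above but not defined in this example. A minimal sketch of a reverse-Huber (BerHu) loss, assuming the common batch-adaptive threshold c = 0.2 * max|error| and an unreduced, per-pixel output so that the mask indexing in the loop above still applies:

import torch

def berhu_loss_elementwise(prediction, target, threshold_ratio=0.2):
    # Reverse Huber (BerHu): L1 below the threshold c, scaled L2 above it.
    diff = torch.abs(prediction - target)
    c = threshold_ratio * diff.max().detach()            # batch-adaptive threshold
    l2_part = (diff ** 2 + c ** 2) / (2 * c + 1e-12)     # (d^2 + c^2) / (2c) for |d| > c
    return torch.where(diff <= c, diff, l2_part)         # per-pixel loss, no reduction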
Code Example #26
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        sbd_path=cfg["data"]["sbd_path"],
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        sbd_path=cfg["data"]["sbd_path"],
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
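
Every example above builds its optimizer the same way: get_optimizer(cfg) resolves the class named in the config, and the remaining keys of cfg['training']['optimizer'] are passed through as keyword arguments. A self-contained sketch of that pattern; the registry and config values below are illustrative assumptions, not the project's actual ones:

import torch.nn as nn
import torch.optim as optim

_OPTIMIZERS = {"sgd": optim.SGD, "adam": optim.Adam}     # illustrative registry

def build_optimizer(cfg, params):
    opt_cfg = dict(cfg["training"]["optimizer"])         # e.g. {'name': 'sgd', 'lr': 0.01, 'momentum': 0.9}
    cls = _OPTIMIZERS[opt_cfg.pop("name").lower()]       # 'name' picks the class ...
    return cls(params, **opt_cfg)                        # ... everything else is a keyword argument

model = nn.Linear(4, 2)
cfg = {"training": {"optimizer": {"name": "sgd", "lr": 0.01, "momentum": 0.9}}}
optimizer = build_optimizer(cfg, model.parameters())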
Code Example #27
def train(cfg, logger):
    
    # Setup seeds   ME: take these out for random samples
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE: ",device)

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    
    if torch.cuda.is_available():
        data_path = cfg['data']['server_path']
    else:
        data_path = cfg['data']['path']
    
    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)
    
    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=True)

    number_of_images_training = t_loader.number_of_images
    
    # Setup Hierarchy
    
    if torch.cuda.is_available():
        if cfg['data']['dataset'] == "vistas":
            if cfg['data']['viking']:
                root = create_tree_from_textfile("/users/brm512/scratch/experiments/meetshah-semseg/mapillary_tree.txt")
            else:
                root = create_tree_from_textfile("/home/userfs/b/brm512/experiments/meetshah-semseg/mapillary_tree.txt")
        elif cfg['data']['dataset'] == "faces":
            if cfg['data']['viking']:
                root = create_tree_from_textfile("/users/brm512/scratch/experiments/meetshah-semseg/faces_tree.txt")
            else:
                root = create_tree_from_textfile("/home/userfs/b/brm512/experiments/meetshah-semseg/faces_tree.txt")
    else:
        if cfg['data']['dataset'] == "vistas":
            root = create_tree_from_textfile("/home/brm512/Pytorch/meetshah-semseg/mapillary_tree.txt")
        elif cfg['data']['dataset'] == "faces":
            root = create_tree_from_textfile("/home/brm512/Pytorch/meetshah-semseg/faces_tree.txt")

    add_channels(root,0)
    add_levels(root,find_depth(root))
    
    class_lookup = [0,10,7,8,9,1,6,4,5,2,3]  # correcting for tree channel and data integer class correspondence  # HELEN
    #class_lookup = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,48,51,45,46,47,49,50,52,53,54,55,56,57,58,59,60,61,62,63,64,65] # VISTAS
    update_channels(root, class_lookup)

    # Setup models for Hierarchical and Standard training. Note we use Tree synonymously with hierarchy

    model_nontree = get_model(cfg['model'], n_classes).to(device)
    model_tree = get_model(cfg['model'], n_classes).to(device)
    model_nontree = torch.nn.DataParallel(model_nontree, device_ids=range(torch.cuda.device_count()))
    model_tree = torch.nn.DataParallel(model_tree, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls_nontree = get_optimizer(cfg)
    optimizer_params_nontree = {k:v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
    optimizer_nontree = optimizer_cls_nontree(model_nontree.parameters(), **optimizer_params_nontree)
    logger.info("Using non tree optimizer {}".format(optimizer_nontree))

    optimizer_cls_tree = get_optimizer(cfg)
    optimizer_params_tree = {k:v for k, v in cfg['training']['optimizer'].items() 
                        if k != 'name'}
    optimizer_tree = optimizer_cls_tree(model_tree.parameters(), **optimizer_params_tree)
    logger.info("Using non tree optimizer {}".format(optimizer_tree))
    
    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    loss_meter_nontree = averageMeter()
    if cfg['training']['use_hierarchy']:
        loss_meter_level0_nontree = averageMeter()
        loss_meter_level1_nontree = averageMeter()
        loss_meter_level2_nontree = averageMeter()
        loss_meter_level3_nontree = averageMeter()
        
    loss_meter_tree = averageMeter()
    if cfg['training']['use_hierarchy']:
        loss_meter_level0_tree = averageMeter()
        loss_meter_level1_tree = averageMeter()
        loss_meter_level2_tree = averageMeter()
        loss_meter_level3_tree = averageMeter()
        
        
    time_meter = averageMeter()
    epoch = 0
    i = 0
    flag = True
    number_epoch_iters = number_of_images_training / cfg['training']['batch_size']
    
# TRAINING
    start_training_time = time.time()
    
    while i < cfg['training']['train_iters'] and flag and epoch < cfg['training']['epochs']:
       
        epoch_start_time = time.time()
        epoch = epoch + 1
        for (images, labels) in trainloader:
            i = i + 1
            start_ts = time.time()
        
            model_nontree.train()
            model_tree.train()
            
            images = images.to(device)
            labels = labels.to(device)

            optimizer_nontree.zero_grad()
            optimizer_tree.zero_grad()
            
            outputs_nontree = model_nontree(images)
            outputs_tree = model_tree(images)

            #nontree loss calculation
            if cfg['training']['use_tree_loss']:
                loss_nontree = loss_fn(input=outputs_nontree, target=labels, root=root, use_hierarchy = cfg['training']['use_hierarchy'])
                level_losses_nontree = loss_nontree[1]
                mainloss_nontree = loss_fn(input=outputs_nontree, target=labels, root=root, use_hierarchy = False)[0]
            else:
                loss_nontree = loss_fn(input=outputs_nontree, target=labels)
                mainloss_nontree = loss_nontree
            
            #tree loss calculation
            if cfg['training']['use_tree_loss']:
                loss_tree = loss_fn(input=outputs_tree, target=labels, root=root, use_hierarchy = cfg['training']['use_hierarchy'])
                level_losses_tree = loss_tree[1]
                mainloss_tree = loss_tree[0]
            else:
                loss_tree = loss_fn(input=outputs_tree, target=labels)
                mainloss_tree = loss_tree
            
            loss_meter_nontree.update(mainloss_nontree.item())
            if cfg['training']['use_hierarchy'] and not cfg['training']['phased']:
                loss_meter_level0_nontree.update(level_losses_nontree[0])
                loss_meter_level1_nontree.update(level_losses_nontree[1])
                loss_meter_level2_nontree.update(level_losses_nontree[2])
                loss_meter_level3_nontree.update(level_losses_nontree[3])
                
            loss_meter_tree.update(mainloss_tree.item())
            if cfg['training']['use_hierarchy'] and not cfg['training']['phased']:
                loss_meter_level0_tree.update(level_losses_tree[0])
                loss_meter_level1_tree.update(level_losses_tree[1])
                loss_meter_level2_tree.update(level_losses_tree[2])
                loss_meter_level3_tree.update(level_losses_tree[3])

            # optimise nontree and tree
            mainloss_nontree.backward()
            mainloss_tree.backward()
            
            optimizer_nontree.step()
            optimizer_tree.step()

            time_meter.update(time.time() - start_ts)
            
            # For printing/logging stats
            if (i) % cfg['training']['print_interval'] == 0:
                fmt_str = "Epoch [{:d}/{:d}] Iter [{:d}/{:d}] IterNonTreeLoss: {:.4f}  IterTreeLoss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(epoch,cfg['training']['epochs'], i % int(number_epoch_iters),
                                           int(number_epoch_iters), mainloss_nontree.item(), 
                                           mainloss_tree.item(),
                                           time_meter.avg / cfg['training']['batch_size'])
    

                print(print_str)
                logger.info(print_str)
                time_meter.reset()
                
# VALIDATION AFTER EVERY EPOCH
            if (i) % cfg['training']['val_interval'] == 0 or (i) % number_epoch_iters == 0 or (i) == cfg['training']['train_iters']:
                validate(cfg, model_nontree, model_tree, loss_fn, device, root)
                # reset meters after validation
                loss_meter_nontree.reset()
                if cfg['training']['use_hierarchy']:
                    loss_meter_level0_nontree.reset()
                    loss_meter_level1_nontree.reset()
                    loss_meter_level2_nontree.reset()
                    loss_meter_level3_nontree.reset()

                loss_meter_tree.reset()     
                if cfg['training']['use_hierarchy']:
                    loss_meter_level0_tree.reset()
                    loss_meter_level1_tree.reset()
                    loss_meter_level2_tree.reset()
                    loss_meter_level3_tree.reset()
            
            # For debugging
            if (i) == cfg['training']['train_iters']:
                flag = False
                break
            
        print("EPOCH TIME (MIN): ", epoch, (time.time() - epoch_start_time)/60.0)
        logger.info("Epoch %d took %.4f minutes" % (int(epoch) , (time.time() - epoch_start_time)/60.0))
           
    print("TRAINING TIME: ",(time.time() - start_training_time)/3600.0)