Example #1
File: demo.py Project: hepuzheng/GDesign
def demo(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # output folder
    if not os.path.exists(config.outdir):
        os.makedirs(config.outdir)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    model = get_model(config.model, pretrained=True,
                      root=config.save_folder).to(device)
    print('Finished loading model!')

    if config.input_pic is not None:
        image = Image.open(config.input_pic).convert('RGB')
        images = transform(image).unsqueeze(0).to(device)
        test(model, images, config.input_pic)
    else:
        # image transform
        test_dataset = get_segmentation_dataset(config.dataset,
                                                split='test',
                                                mode='test',
                                                transform=transform)
        test_sampler = make_data_sampler(test_dataset, True, False)
        test_batch_sampler = make_batch_data_sampler(test_sampler,
                                                     images_per_batch=1)
        test_loader = data.DataLoader(dataset=test_dataset,
                                      batch_sampler=test_batch_sampler,
                                      num_workers=4,
                                      pin_memory=True)
        for i, (image, target) in enumerate(test_loader):
            image = image.to(device)
            test(model, image, ''.join(target))
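Note: `test` is a helper defined elsewhere in the project. A minimal sketch of what such a helper typically does, reconstructed from Examples #2-#4 below (the signature and the `outdir`/`dataset` defaults are assumptions, not the project's actual API):

def test(model, images, input_path, outdir='.', dataset='pascal_voc'):
    # hypothetical helper: forward pass, per-pixel argmax, save a colorized mask
    model.eval()
    with torch.no_grad():
        output = model(images)                   # first element: [1, C, H, W] logits
    pred = torch.argmax(output[0], 1).squeeze(0).cpu().numpy()
    mask = get_color_pallete(pred, dataset)      # assumed project utility (see Examples #2-#4)
    outname = os.path.splitext(os.path.basename(input_path))[0] + '.png'
    mask.save(os.path.join(outdir, outname))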
Example #2
def demo(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # output folder
    if not os.path.exists(config.outdir):
        os.makedirs(config.outdir)

    # image transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(config.input_pic).convert('RGB')
    images = transform(image).unsqueeze(0).to(device)

    model = get_model(config.model, pretrained=True, root=config.save_folder).to(device)
    print('Finished loading model!')

    model.eval()
    with torch.no_grad():
        output = model(images)

    pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()
    mask = get_color_pallete(pred, config.dataset)
    outname = os.path.splitext(os.path.split(config.input_pic)[-1])[0] + '.png'
    mask.save(os.path.join(config.outdir, outname))
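Note: the `torch.argmax(output[0], 1)` pattern assumes the model returns a tuple whose first element is a [1, C, H, W] logit tensor; the argmax over dim 1 followed by squeeze yields an [H, W] integer class map. A self-contained sanity check of that shape arithmetic:

import torch

logits = torch.randn(1, 21, 480, 640)       # [batch, classes, H, W], e.g. 21 VOC classes
pred = torch.argmax(logits, 1).squeeze(0)   # [H, W] map of per-pixel class ids
assert pred.shape == (480, 640)
assert pred.dtype == torch.int64            # integer indices, ready for a palette lookup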
Example #3
def demo(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # output folder
    if not os.path.exists(config.outdir):
        os.makedirs(config.outdir)

    # image transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(config.input_pic).convert('RGB')
    images = transform(image).unsqueeze(0).to(device)

    model = get_model(config.model, pretrained=True,
                      root=config.save_folder).to(device)
    print('Finished loading model!')

    if config.resume:
        if os.path.isfile(config.resume):
            name, ext = os.path.splitext(config.resume)
            assert ext in ('.pkl', '.pth'), 'Sorry, only .pth and .pkl files are supported.'
            print('Resuming training, loading {}...'.format(config.resume))
            model.load_state_dict(
                torch.load(config.resume,
                           map_location=lambda storage, loc: storage))

    model.eval()
    with torch.no_grad():
        output = model(images)

    pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()
    mask = get_color_pallete(pred, config.dataset)
    outname = os.path.splitext(os.path.split(config.input_pic)[-1])[0] + '.png'
    mask.save(os.path.join(config.outdir, outname))
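Note: `map_location=lambda storage, loc: storage` keeps every deserialized tensor on CPU, so a GPU-trained checkpoint loads on a CPU-only machine. A terser, equivalent idiom (the path is a placeholder):

state = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(state)
model.to(device)  # move to the GPU afterwards if one is available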
Example #4
def demo(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # output folder
    if not os.path.exists(config.outdir):
        os.makedirs(config.outdir)

    # image transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(config.input_pic).convert('RGB')
    images = transform(image).unsqueeze(0).to(device)

    model = get_model(config.model, pretrained=True,
                      root=config.save_folder).to(device)
    print('Finished loading model!')

    model.eval()
    with torch.no_grad():
        output = model(images)

    pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()
    print('predicted masks:', np.unique(pred))

    mask = get_color_pallete(pred, config.dataset)
    outname = os.path.splitext(os.path.split(config.input_pic)[-1])[0] + '.png'
    mask.save(os.path.join(config.outdir, outname))
    mask = cv2.imread(os.path.join(config.outdir, outname), cv2.IMREAD_COLOR)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.array(image), 0.5, mask, 0.5, 0.0)
    blended = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(config.outdir, outname), blended)
Example #5
def load_model(self):
    self.logger.info('> load_model')
    self.model = get_model(self.cfg["model"],
                           self.cfg["data"]["num_classes"])
    self.model.load_state_dict(
        convert_state_dict(
            torch.load(self.cfg['test']['checkpoint'])["model_state"]))
    self.model.to(self.device)
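Note: `convert_state_dict` is a project utility; in codebases of this shape it usually just strips the `module.` prefix that `DataParallel`/DDP wrappers prepend to parameter names (the same trick Example #11 performs inline). A sketch under that assumption:

from collections import OrderedDict

def convert_state_dict(state_dict):
    # assumed behavior: drop the 'module.' prefix added by (Distributed)DataParallel
    return OrderedDict((k.replace('module.', '', 1), v) for k, v in state_dict.items())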
Example #6
File: eval.py Project: lilujunai/MTLNAS
def main():
    parser = argparse.ArgumentParser(description="Baseline Experiment Eval")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Seeding
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    # load the data
    test_loader = torch.utils.data.DataLoader(get_dataset(cfg, 'test'),
                                              batch_size=cfg.TEST.BATCH_SIZE,
                                              shuffle=False,
                                              pin_memory=True)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)

    ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                             'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
    print("Evaluating Checkpoint at %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    if cfg.CUDA:
        model = model.cuda()

    task1_metric, task2_metric = evaluate(test_loader, model, task1, task2)
    for k, v in task1_metric.items():
        print('{}: {:.3f}'.format(k, v))
    for k, v in task2_metric.items():
        print('{}: {:.3f}'.format(k, v))
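Note: the seeding block above recurs in Examples #11 and #12; factoring it into a helper makes the determinism trade-off explicit. A minimal sketch:

import random
import numpy as np
import torch

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                    # also seeds CUDA on recent PyTorch
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # reproducible convolutions
    torch.backends.cudnn.benchmark = False     # disable the autotuner (can slow training)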
Example #7
def demo(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # output folder
    if not os.path.exists(config.outdir):
        os.makedirs(config.outdir)

    # image transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    #read the video file here
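    # A possible setup for the `cap`/`out` handles used below; the video path,
    # codec, fps, and frame size here are assumptions, not the original project's values:
    # cap = cv2.VideoCapture(config.input_video)
    # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # out = cv2.VideoWriter(os.path.join(config.outdir, 'demo_out.mp4'),
    #                       fourcc, 30.0, (640, 480))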
    model = get_model(config.model, pretrained=True, root=config.save_folder).to(device)
    model.eval()
    print('Finished loading model!')
    count = 0
    pbar = tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))  # total frames in the input video
    while cap.isOpened():
        count += 1
        ret, image = cap.read()
        if not ret:  # end of the stream
            break
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB
        image = cv2.resize(image, (640, 480))
        # image = cv2.flip(image, 1)
        image = cv2.GaussianBlur(image, (5, 5), 0)
        images = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(images)

        pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()
        mask = get_color_pallete(pred, config.dataset)
        outname = 'tmp.png'  # scratch file for the colorized mask
        mask.save(os.path.join(config.outdir, outname))
        mask = cv2.imread(os.path.join(config.outdir, outname))  # read back in BGR
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        blended = cv2.addWeighted(image, 0.5, mask, 0.5, 0.0)
        if config.display:
            cv2.imshow('output', blended)
            cv2.waitKey(1)
        out.write(blended)
        pbar.update(1)

        # if count==300: break

    cap.release()
    out.release()
    print('Done. Video file generated')
Example #8
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Generate the train and validation sets for the model:
    split_train_val(args, per_val=args.per_val)

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join(
        'runs', current_time + "_{}_{}".format(args.arch, args.loss))
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(30),
             RandomHorizontallyFlip(),
             AddNoise()])
    else:
        data_aug = None

    train_set = patch_loader(is_transform=True,
                             split='train',
                             stride=args.stride,
                             patch_size=args.patch_size,
                             augmentations=data_aug)

    # Without Augmentation:
    val_set = patch_loader(is_transform=True,
                           split='val',
                           stride=args.stride,
                           patch_size=args.patch_size)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(args.arch, args.pretrained, n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: ALWAYS CONSTRUCT OPTIMIZERS AFTER THE MODEL IS PUSHED TO GPU/CPU.

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adadelta(model.parameters())
        optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

    if args.loss == 'FL':
        loss_fn = core.loss.focal_loss2d
    else:
        loss_fn = core.loss.cross_entropy

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0
    class_names = [
        'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk', 'scruff',
        'zechstein'
    ]

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # training
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for i, (images, labels) in enumerate(trainloader):
            image_original, labels_original = images, labels
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            pred = outputs.detach().max(1)[1].cpu().numpy()
            gt = labels.detach().cpu().numpy()
            running_metrics.update(gt, pred)

            loss = loss_fn(input=outputs, target=labels, weight=class_weights)
            loss_train += loss.item()
            loss.backward()

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0]
            if i in numbers:
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(image_original[0][0],
                                                     normalize=True,
                                                     scale_each=True)
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)

                labels_original = labels_original.numpy()[0]
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label', correct_label_decoded,
                                 epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(confidence,
                                                 normalize=True,
                                                 scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', decoded, epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(decoded_channel,
                                                  normalize=True,
                                                  scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch + 1)
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)
        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch + 1)

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val,
                            labels_val) in tqdm(enumerate(valloader)):
                    image_original, labels_original = images_val, labels_val
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.detach().cpu().numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val, target=labels_val)

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0],
                            normalize=True,
                            scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         correct_label_decoded, epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(confidence,
                                                         normalize=True,
                                                         scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', decoded, epoch + 1)
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(unary[0][channel],
                                                          normalize=True,
                                                          scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)

                writer.add_scalar('val/loss', loss.item(), epoch + 1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_{args.loss}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model:
            if (epoch + 1) % 5 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_{args.loss}_ep{epoch+1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()
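Note: for the default cross-entropy path, the class-weight tensor plays the same role as the `weight` argument of `F.cross_entropy`: each pixel's loss is scaled by the weight of its true class, so under-represented facies contribute more. A self-contained sketch (assuming `core.loss.cross_entropy` wraps the standard weighted form):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 6, 8, 8)                  # [batch, 6 facies classes, H, W]
labels = torch.randint(0, 6, (2, 8, 8))
weights = torch.tensor([0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852])
loss = F.cross_entropy(logits, labels, weight=weights)  # rarer classes count more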
Example #9
def demo(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # output folder
    if not os.path.exists(config.outdir):
        os.makedirs(config.outdir)

    # image transform
    transform = transforms.Compose([
        # transforms.Resize([320, 320]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    model = get_model(config.model,
                      pretrained=True,
                      root=config.save_folder,
                      local_rank=config.local_rank).to(device)
    print('Finished loading model!')

    model.eval()


    files = os.listdir('/train/trainset/1/img_align')
    for i, file in enumerate(files):
        print("%d haved done" % i)
        with torch.no_grad():
            # image = Image.open(config.input_pic).convert('RGB')
            path = os.path.join('/train/trainset/1/img_align', file)
            ori_image = cv2.imread(path)
            ori_image = ori_image.astype(np.float32)
            out_size = config.crop_size
            h = ori_image.shape[0]
            w = ori_image.shape[1]
            x1 = int(round((w - out_size) / 2.))
            y1 = int(round((h - out_size) / 2.))
            ori_image = ori_image[y1:y1 + out_size, x1:x1 + out_size, :]
            image = Image.open(path).convert('RGB')
            image = center_crop(image)
            images = transform(image).unsqueeze(0).to(device)
            output = model(images)

        pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()

        colormap = dict2array(VOC21)
        parsing_color = colormap[pred.astype(int)]

        idx = np.nonzero(pred)
        ori_image[idx[0], idx[1], :] *= 1.0 - 0.4
        ori_image += 0.4 * parsing_color

        ori_image = ori_image.astype(np.uint8)

        outname = os.path.basename(path).replace('.jpg', '.png')
        cv2.imwrite(os.path.join(config.outdir, outname), ori_image)
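Note: the crop offsets follow the usual center-crop arithmetic: `x1` is the horizontal margin computed from the width, `y1` the vertical margin from the height, and NumPy indexes rows (height) first, hence `[y1:y1+s, x1:x1+s]`. A quick check:

import numpy as np

img = np.zeros((480, 640, 3), dtype=np.float32)  # (H, W, C)
s = 224
x1 = int(round((640 - s) / 2.))   # 208, horizontal offset
y1 = int(round((480 - s) / 2.))   # 128, vertical offset
crop = img[y1:y1 + s, x1:x1 + s, :]
assert crop.shape == (224, 224, 3)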
Example #10
def train(cfg, writer, logger):
    # Setup seeds
    seed = cfg['data'].get('seed', 1336)
    setup_seeds(seed)

    # Setup device
    device = setup_device(cfg.get('gpus', '0'))

    # Setup Augmentations
    train_aug = get_composed_augmentations(cfg["augmentations"].get("train_augmentations", None))
    valid_aug = get_composed_augmentations(cfg["augmentations"].get("valid_augmentations", None))

    # Setup Dataloader
    dataset_cls = get_dataset(cfg["data"]["dataset"])
    train_label_path_list = cfg["data"]["train_label_path_list"]
    valid_label_path = cfg["data"]["valid_label_path"]
    num_classes = cfg["data"]["num_classes"]
    batch_size = cfg["train"]["batch_size"]
    num_workers = cfg["train"]["n_workers"]
    data_root = cfg["data"]["data_root"]
    x_key = cfg["data"]["x_key"]
    # [train_class, test_class] = split_dataset_by_csv(train_label_path_list, valid_label_path, x_key=x_key)

    train_loader, valid_loader = setup_dataloader(
        dataset_cls, train_label_path_list, valid_label_path,
        x_key=x_key, data_root=data_root,
        batch_size=batch_size,
        num_workers=num_workers, train_aug=train_aug, valid_aug=valid_aug)

    logger.info('len(train_loader): {}'.format(len(train_loader)))
    logger.info('len(valid_loader): {}'.format(len(valid_loader)))

    # Setup Model
    model = get_model(cfg["model"], num_classes).to(device)
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.train()

    logger.info('>> Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Setup optimizer, lr_scheduler
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg["train"]["optimizer"].items() if k != "name"}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    scheduler_name = cfg["train"]["lr_schedule"]["name"]
    scheduler = get_scheduler(optimizer, cfg["train"]["lr_schedule"])

    # Setup loss function
    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_epoch = 1
    if cfg["train"]["resume"] is not None:
        if os.path.isfile(cfg["train"]["resume"]):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(cfg["train"]["resume"]))
            checkpoint = torch.load(cfg["train"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_epoch = checkpoint["epoch"] + 1
            logger.info("Loaded checkpoint '{}' (epoch {})".format(cfg["train"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(cfg["train"]["resume"]))

    ####################################################################################################################
    #  epoch
    ####################################################################################################################
    curr_epoch = start_epoch
    valid_step_list = list(np.linspace(0, len(train_loader), num=cfg['train']['n_valid_per_epoch']+1, endpoint=True))[1:]
    valid_step_list = [int(step) for step in valid_step_list]
    print('valid_step_list', valid_step_list)
    while curr_epoch <= cfg["train"]["n_epoch"]:
        start_ts = time.time()
        train_q_loss_meter = AverageMeter()
        train_a_loss_meter = AverageMeter()
        valid_q_loss_meter = AverageMeter()
        valid_iou_meter = AverageMeter()
        valid_acc_meter = AverageMeter()
        valid_f1_meter = AverageMeter()
        ################################################################################################################
        #  train
        ################################################################################################################
        for train_i, (train_support, train_smasks_fg, train_smasks_bg, train_query, train_qmask, _) in enumerate(tqdm.tqdm(train_loader)):
            model.train()
            # print('train_support', train_support.shape)
            # print('train_smasks_fg', train_smasks_fg.shape)
            # print('train_query', train_query.shape)
            # print('train_qmask', train_qmask.shape)
            # train_support torch.Size([1, 1, 5, 3, 224, 224])
            # train_smasks_fg torch.Size([1, 1, 5, 1, 224, 224])
            # train_query torch.Size([1, 1, 3, 224, 224])
            # train_qmask torch.Size([1, 1, 1, 224, 224])

            # Prepare input (remove the batch dimension)
            support_images = train_support[0].to(device)
            support_fg_mask = train_smasks_fg[0].float().to(device)
            support_bg_mask = train_smasks_bg[0].float().to(device)

            query_images = train_query[0].to(device)
            query_labels = torch.cat([query_label.long().to(device) for query_label in train_qmask[0]], dim=0)

            # Forward and Backward
            optimizer.zero_grad()
            query_pred, align_loss = model(support_images, support_fg_mask, support_bg_mask, query_images)
            query_loss = loss_fn(query_pred, query_labels)
            loss = query_loss + align_loss * 1 # _config['align_loss_scaler']
            loss.backward()
            optimizer.step()
            if isinstance(query_loss, torch.Tensor) and isinstance(align_loss, torch.Tensor):
                train_q_loss_meter.update(query_loss.data.cpu().numpy())
                # guards against align_loss coming back as a plain float (no .data attribute)
                train_a_loss_meter.update(align_loss.data.cpu().numpy())
            else:
                print('>>> type(loss) is not tensor.')
                print('>>> query_loss: ', query_loss)
                print('>>> align_loss: ', align_loss)

            n_step = train_i + 1
            n_step_global = int((curr_epoch - 1) * len(train_loader) + n_step)
            ############################################################################################################
            #  validation
            ############################################################################################################
            if n_step in valid_step_list:
                # gt_all = torch.FloatTensor().to(device)
                # pred_all = torch.FloatTensor().to(device)
                # model.eval()
                with torch.no_grad():
                    for valid_i, (valid_support, valid_smasks_fg, valid_smasks_bg, valid_query, valid_qmask, q_img_path) in enumerate(valid_loader):
                        # Prepare input (remove the batch dimension)
                        _support_images = valid_support[0].to(device)
                        _support_fg_mask = valid_smasks_fg[0].float().to(device)
                        _support_bg_mask = valid_smasks_bg[0].float().to(device)

                        _query_images = valid_query[0].to(device)
                        _query_labels = torch.cat([query_label.long().to(device) for query_label in valid_qmask[0]], dim=0)

                        _query_pred, _ = model(_support_images, _support_fg_mask, _support_bg_mask, _query_images)

                        _query_loss = loss_fn(_query_pred, _query_labels)


                        if isinstance(_query_loss, torch.Tensor):
                            valid_q_loss_meter.update(_query_loss.data.cpu().numpy())
                        else:
                            print('>>> type(loss) is not tensor.')
                            print('>>> valid query_loss: ', _query_loss)

                        _query_pred = _query_pred.argmax(dim=1)[0].data.cpu().numpy()
                        _query_labels = _query_labels[0].data.cpu().numpy()
                        # print('query_pred.shape', query_pred.shape)
                        # print('query_labels.shape', query_labels.shape)
                        # query_pred.shape (224, 224)
                        # query_labels.shape (224, 224)
                        # index, count = np.unique(query_pred, return_counts=True)
                        # print('count query_pred', index, count)
                        # index, count = np.unique(query_labels, return_counts=True)
                        # print('count query_labels', index, count)
                        # pred_y = torch.sigmoid(pred_y)
                        # pred_y = torch.round(pred_y)
                            
                        iou = compute_iou(_query_labels, _query_pred)
                        acc = compute_acc(_query_labels, _query_pred)
                        f1  = compute_f1( _query_labels, _query_pred)
                        valid_iou_meter.update(iou)
                        valid_acc_meter.update(acc)
                        valid_f1_meter.update(f1)

                train_loss = np.round(train_q_loss_meter.avg, 4)
                train_loss_a = np.round(train_a_loss_meter.avg, 4)
                valid_loss = np.round(valid_q_loss_meter.avg, 4)
                valid_iou = np.round(valid_iou_meter.avg, 4)
                valid_acc = np.round(valid_acc_meter.avg, 4)
                valid_f1 = np.round(valid_f1_meter.avg, 4)
                train_q_loss_meter.reset()
                train_a_loss_meter.reset()
                valid_q_loss_meter.reset()
                valid_iou_meter.reset()
                valid_acc_meter.reset()
                valid_f1_meter.reset()
                current_lr = 0
                for param_group in optimizer.param_groups:
                    current_lr = param_group['lr']

                logger.info(f'>> Epoch[{int(curr_epoch)}/{int(cfg["train"]["n_epoch"])}]')
                logger.info(f'>> {datetime.datetime.now()} Train_Loss(q): {train_loss} Train_Loss(a): {train_loss_a}')
                logger.info(f'>> {datetime.datetime.now()} Validation_Loss: {valid_loss} IoU:{valid_iou} acc:{valid_acc} f1:{valid_f1} Current LR: {current_lr}')
                state = {
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "epoch": curr_epoch,
                }
                save_name = (f'{cfg["model"]["arch"]}_{cfg["data"]["dataset"]}'
                             f'.iou({valid_iou})'
                             f'.acc({valid_acc})'
                             f'.f1({valid_f1})'
                             f'.epoch({curr_epoch})'
                             '.pth.tar')
                # metrics.save_model_state(state, save_path=cfg["train"]["save_dir_path"], save_name=save_name)
                torch.save(state, os.path.join(cfg["train"]["save_dir_path"], save_name))

                writer.add_scalar("loss/train_loss", train_loss, n_step_global)
                writer.add_scalar("loss/tarin_loss_a", train_loss_a, n_step_global)
                writer.add_scalar("loss/valid_loss", valid_loss, n_step_global)
                writer.add_scalar("loss/valid_iou", valid_iou, n_step_global)
                writer.add_scalar("loss/valid_acc", valid_acc, n_step_global)
                writer.add_scalar("loss/valid_f1", valid_f1, n_step_global)
                writer.add_scalar("learning_rate", currunt_lr, n_step_global)
                start_ts = time.time()

        curr_epoch += 1
        scheduler.step()
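Note: `AverageMeter` is not defined in the excerpt; the name conventionally refers to the small running-average tracker from the official PyTorch ImageNet example. A sketch under that assumption:

class AverageMeter:
    """Tracks the latest value, running sum, count, and average of a scalar."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count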
Example #11
def main():
    parser = argparse.ArgumentParser(description="PyTorch MTLNAS Eval")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--port", type=int, default=29502)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # Preparing for DDP training
    logging = args.local_rank == 0
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(args.port)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    # Adjust batch size for distributed training
    assert cfg.TRAIN.BATCH_SIZE % num_gpus == 0
    cfg.TRAIN.BATCH_SIZE = int(cfg.TRAIN.BATCH_SIZE // num_gpus)
    assert cfg.TEST.BATCH_SIZE % num_gpus == 0
    cfg.TEST.BATCH_SIZE = int(cfg.TEST.BATCH_SIZE // num_gpus)
    cfg.freeze()

    # Seeding
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    if not os.path.exists(os.path.join(cfg.SAVE_DIR,
                                       cfg.EXPERIMENT_NAME)) and logging:
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    test_data = get_dataset(cfg, 'test')

    if distributed:
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_data)
    else:
        test_sampler = None

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=cfg.TEST.BATCH_SIZE,
                                              shuffle=False,
                                              sampler=test_sampler)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)

    if cfg.CUDA:
        model = model.cuda()

    ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                             'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
    print("Evaluating Checkpoint at %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    # compatibility with ddp saved checkpoint when evaluating without ddp
    pretrain_dict = {
        k.replace('module.', ''): v
        for k, v in ckpt['model_state_dict'].items()
    }
    model_dict = model.state_dict()
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)

    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = MyDataParallel(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank,
                               find_unused_parameters=True)

    model.eval()

    task1_metric, task2_metric = evaluate(test_loader, model, task1, task2,
                                          distributed, args.local_rank)
    if logging:
        for k, v in task1_metric.items():
            print('{}: {:.9f}'.format(k, v))
        for k, v in task2_metric.items():
            print('{}: {:.9f}'.format(k, v))
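Note: `synchronize()` is imported from the project; in DDP scripts of this shape it is typically a thin barrier that no-ops outside multi-process training. A sketch under that assumption:

import torch.distributed as dist

def synchronize():
    # barrier across all ranks; no-op when DDP is absent or single-process
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()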
Example #12
def main():
    parser = argparse.ArgumentParser(description="PyTorch MTLNAS Training")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--port", type=int, default=29501)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # Preparing for DDP training
    logging = args.local_rank == 0
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(args.port)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    # Adjust batch size for distributed training
    assert cfg.TRAIN.BATCH_SIZE % num_gpus == 0
    cfg.TRAIN.BATCH_SIZE = int(cfg.TRAIN.BATCH_SIZE // num_gpus)
    assert cfg.TEST.BATCH_SIZE % num_gpus == 0
    cfg.TEST.BATCH_SIZE = int(cfg.TEST.BATCH_SIZE // num_gpus)
    cfg.freeze()

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d~%H:%M:%S")
    experiment_log_dir = os.path.join(cfg.LOG_DIR, cfg.EXPERIMENT_NAME,
                                      timestamp)
    if not os.path.exists(experiment_log_dir) and logging:
        os.makedirs(experiment_log_dir)
        writer = SummaryWriter(logdir=experiment_log_dir)
    printf = get_print(experiment_log_dir)
    printf("Training with Config: ")
    printf(cfg)

    # Seeding
    os.environ['PYTHONHASHSEED'] = str(cfg.SEED)
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    if not os.path.exists(os.path.join(cfg.SAVE_DIR,
                                       cfg.EXPERIMENT_NAME)) and logging:
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    train_full_data = get_dataset(cfg, 'train')

    num_train = len(train_full_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.ARCH.TRAIN_SPLIT * num_train))

    # load the data
    if cfg.TRAIN.EVAL_CKPT:
        test_data = get_dataset(cfg, 'val')

        if distributed:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_data)
        else:
            test_sampler = None

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=cfg.TEST.BATCH_SIZE,
            shuffle=False,
            sampler=test_sampler)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)

    if cfg.CUDA:
        model = model.cuda()

    if distributed:
        # Important: Double check if BN is working as expected
        if cfg.TRAIN.APEX:
            printf("using apex synced BN")
            model = apex.parallel.convert_syncbn_model(model)
        else:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = MyDataParallel(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank,
                               find_unused_parameters=True)

    # hacky way to pick params
    nddr_params = []
    fc8_weights = []
    fc8_bias = []
    base_params = []
    for k, v in model.named_net_parameters():
        if 'paths' in k:
            nddr_params.append(v)
        elif model.net1.fc_id in k:
            if 'weight' in k:
                fc8_weights.append(v)
            else:
                assert 'bias' in k
                fc8_bias.append(v)
        else:
            assert 'alpha' not in k
            base_params.append(v)
    assert len(nddr_params) > 0 and len(fc8_weights) > 0 and len(fc8_bias) > 0

    parameter_dict = [{
        'params': base_params
    }, {
        'params': fc8_weights,
        'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_WEIGHT_FACTOR
    }, {
        'params': fc8_bias,
        'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_BIAS_FACTOR
    }, {
        'params': nddr_params,
        'lr': cfg.TRAIN.LR * cfg.TRAIN.NDDR_FACTOR
    }]
    optimizer = optim.SGD(parameter_dict,
                          lr=cfg.TRAIN.LR,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    if cfg.ARCH.OPTIMIZER == 'sgd':
        arch_optimizer = torch.optim.SGD(
            model.arch_parameters(),
            lr=cfg.ARCH.LR,
            momentum=cfg.TRAIN.MOMENTUM,  # TODO: separate this param
            weight_decay=cfg.ARCH.WEIGHT_DECAY)
    else:
        arch_optimizer = torch.optim.Adam(model.arch_parameters(),
                                          lr=cfg.ARCH.LR,
                                          betas=(0.5, 0.999),
                                          weight_decay=cfg.ARCH.WEIGHT_DECAY)

    if cfg.TRAIN.SCHEDULE == 'Poly':
        if cfg.TRAIN.WARMUP > 0.:
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda step: min(1.,
                                 float(step) / cfg.TRAIN.WARMUP) *
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
            arch_scheduler = optim.lr_scheduler.LambdaLR(
                arch_optimizer,
                lambda step: min(1.,
                                 float(step) / cfg.TRAIN.WARMUP) *
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
        else:
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda step:
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
            arch_scheduler = optim.lr_scheduler.LambdaLR(
                arch_optimizer,
                lambda step:
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
    elif cfg.TRAIN.SCHEDULE == 'Cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, cfg.TRAIN.STEPS)
        arch_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            arch_optimizer, cfg.TRAIN.STEPS)
    elif cfg.TRAIN.SCHEDULE == 'Step':
        milestones = (np.array([0.6, 0.9]) * cfg.TRAIN.STEPS).astype('int')
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones,
                                                   gamma=0.1)
        arch_scheduler = optim.lr_scheduler.MultiStepLR(arch_optimizer,
                                                        milestones,
                                                        gamma=0.1)
    else:
        raise NotImplementedError

    if cfg.TRAIN.APEX:
        model, [arch_optimizer,
                optimizer] = amp.initialize(model, [arch_optimizer, optimizer],
                                            opt_level="O1",
                                            num_losses=2)

    model.train()
    steps = 0
    while steps < cfg.TRAIN.STEPS:
        # Initialize train/val dataloader below this shuffle operation
        # to ensure both arch and weights gets to see all the data,
        # but not at the same time during mixed data training
        if cfg.ARCH.MIXED_DATA:
            np.random.shuffle(indices)

        train_data = torch.utils.data.Subset(train_full_data, indices[:split])
        val_data = torch.utils.data.Subset(train_full_data,
                                           indices[split:num_train])

        if distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            train_sampler = None
            val_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            pin_memory=True,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            pin_memory=True,
            sampler=val_sampler)

        val_iter = iter(val_loader)

        if distributed:
            train_sampler.set_epoch(steps)  # steps is used to seed RNG
            val_sampler.set_epoch(steps)

        for batch_idx, (image, label_1, label_2) in enumerate(train_loader):
            if cfg.CUDA:
                image, label_1, label_2 = image.cuda(), label_1.cuda(
                ), label_2.cuda()

            # get a random minibatch from the search queue without replacement
            val_batch = next(val_iter, None)
            if val_batch is None:  # val_iter has reached its end
                if distributed:
                    val_sampler.set_epoch(steps)
                val_iter = iter(val_loader)
                val_batch = next(val_iter)
            image_search, label_1_search, label_2_search = val_batch
            image_search = image_search.cuda()
            label_1_search, label_2_search = label_1_search.cuda(
            ), label_2_search.cuda()

            # setting flag for training arch parameters
            model.arch_train()
            assert model.arch_training
            arch_optimizer.zero_grad()
            arch_result = model.loss(image_search,
                                     (label_1_search, label_2_search))
            arch_loss = arch_result.loss

            # Mixed Precision
            if cfg.TRAIN.APEX:
                with amp.scale_loss(arch_loss, arch_optimizer,
                                    loss_id=0) as scaled_loss:
                    scaled_loss.backward()
            else:
                arch_loss.backward()

            arch_optimizer.step()
            model.arch_eval()

            assert not model.arch_training
            optimizer.zero_grad()

            result = model.loss(image, (label_1, label_2))

            out1, out2 = result.out1, result.out2
            loss1 = result.loss1
            loss2 = result.loss2

            loss = result.loss

            # Mixed Precision
            if cfg.TRAIN.APEX:
                with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            if cfg.ARCH.SEARCHSPACE == 'GeneralizedMTLNAS':
                model.step()  # update model temperature
            scheduler.step()
            if cfg.ARCH.OPTIMIZER == 'sgd':
                arch_scheduler.step()

            # Print out the loss periodically.
            if steps % cfg.TRAIN.LOG_INTERVAL == 0 and logging:
                printf(
                    'Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLoss1: {:.6f}\tLoss2: {:.6f}'
                    .format(steps, batch_idx * len(image),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                            loss.data.item(), loss1.data.item(),
                            loss2.data.item()))

                # Log to tensorboard
                writer.add_scalar('lr', scheduler.get_lr()[0], steps)
                writer.add_scalar('arch_lr', arch_scheduler.get_lr()[0], steps)
                writer.add_scalar('loss/overall', loss.data.item(), steps)
                writer.add_image(
                    'image', process_image(image[0],
                                           train_full_data.image_mean), steps)
                task1.log_visualize(out1, label_1, loss1, writer, steps)
                task2.log_visualize(out2, label_2, loss2, writer, steps)

                if cfg.ARCH.ENTROPY_REGULARIZATION:
                    writer.add_scalar('loss/entropy_weight',
                                      arch_result.entropy_weight, steps)
                    writer.add_scalar('loss/entropy_loss',
                                      arch_result.entropy_loss.data.item(),
                                      steps)

                if cfg.ARCH.L1_REGULARIZATION:
                    writer.add_scalar('loss/l1_weight', arch_result.l1_weight,
                                      steps)
                    writer.add_scalar('loss/l1_loss',
                                      arch_result.l1_loss.data.item(), steps)

                if cfg.ARCH.SEARCHSPACE == 'GeneralizedMTLNAS':
                    writer.add_scalar('temperature', model.get_temperature(),
                                      steps)
                    alpha1 = torch.sigmoid(
                        model.net1_alphas).detach().cpu().numpy()
                    alpha2 = torch.sigmoid(
                        model.net2_alphas).detach().cpu().numpy()
                    alpha1_path = os.path.join(experiment_log_dir, 'alpha1')
                    if not os.path.isdir(alpha1_path):
                        os.makedirs(alpha1_path)
                    alpha2_path = os.path.join(experiment_log_dir, 'alpha2')
                    if not os.path.isdir(alpha2_path):
                        os.makedirs(alpha2_path)
                    heatmap1 = save_heatmap(
                        alpha1,
                        os.path.join(alpha1_path,
                                     "%s_alpha1.png" % str(steps).zfill(5)))
                    heatmap2 = save_heatmap(
                        alpha2,
                        os.path.join(alpha2_path,
                                     "%s_alpha2.png" % str(steps).zfill(5)))
                    writer.add_image('alpha/net1', heatmap1, steps)
                    writer.add_image('alpha/net2', heatmap2, steps)
                    network_path = os.path.join(experiment_log_dir, 'network')
                    if not os.path.isdir(network_path):
                        os.makedirs(network_path)
                    connectivity_plot = save_connectivity(
                        alpha1, alpha2, model.net1_connectivity_matrix,
                        model.net2_connectivity_matrix,
                        os.path.join(network_path,
                                     "%s_network.png" % str(steps).zfill(5)))
                    writer.add_image('network', connectivity_plot, steps)

            if steps % cfg.TRAIN.EVAL_INTERVAL == 0:
                if distributed:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()

                checkpoint = {
                    'cfg': cfg,
                    'step': steps,
                    'model_state_dict': state_dict,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss,
                    'loss1': loss1,
                    'loss2': loss2,
                    'task1_metric': None,
                    'task2_metric': None,
                }

                if cfg.TRAIN.EVAL_CKPT:
                    model.eval()
                    torch.cuda.empty_cache()  # TODO check if it helps
                    task1_metric, task2_metric = evaluate(
                        test_loader, model, task1, task2, distributed,
                        args.local_rank)

                    if logging:
                        for k, v in task1_metric.items():
                            writer.add_scalar('eval/{}'.format(k), v, steps)
                        for k, v in task2_metric.items():
                            writer.add_scalar('eval/{}'.format(k), v, steps)
                        for k, v in task1_metric.items():
                            printf('{}: {:.3f}'.format(k, v))
                        for k, v in task2_metric.items():
                            printf('{}: {:.3f}'.format(k, v))

                    checkpoint['task1_metric'] = task1_metric
                    checkpoint['task2_metric'] = task2_metric
                    model.train()
                    torch.cuda.empty_cache()  # TODO check if it helps

                if logging and steps % cfg.TRAIN.SAVE_INTERVAL == 0:
                    torch.save(
                        checkpoint,
                        os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                                     'ckpt-%s.pth' % str(steps).zfill(5)))

            if steps >= cfg.TRAIN.STEPS:
                break
            steps += 1  # train for one extra iteration to allow time for tensorboard logging
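
Note: the save_heatmap helper used above is not shown in this snippet. A plausible minimal sketch (assumptions: matplotlib with the Agg backend, and a return value shaped CxHxW uint8 so SummaryWriter.add_image accepts it directly):

import matplotlib
matplotlib.use('Agg')  # headless rendering
import matplotlib.pyplot as plt
import numpy as np
import torch

def save_heatmap(arr, path):
    """Render a 1-D/2-D alpha array as a heatmap, save it to `path`,
    and return a CxHxW uint8 tensor for SummaryWriter.add_image."""
    fig, ax = plt.subplots()
    ax.imshow(np.atleast_2d(arr), cmap='viridis', aspect='auto')
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())  # HxWx4 uint8
    fig.savefig(path)
    plt.close(fig)
    return torch.from_numpy(rgba[..., :3].copy()).permute(2, 0, 1)  # HWC -> CHW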
Example #13
def train(args):
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  #Selects Torch Device
    split_train_val(
        args, per_val=args.per_val
    )  #Generate the train and validation sets for the model as text files:

    current_time = datetime.now().strftime(
        '%b%d_%H%M%S')  #Gets Current Time and Date
    log_dir = os.path.join(
        'runs', current_time +
        f"_{args.arch}_{args.model_name}")  #Create the log directory
    writer = SummaryWriter(
        log_dir=log_dir)  #Initialize the tensorboard summary writer

    # Setup Augmentations
    if args.aug:  #if augmentation is true
        data_aug = Compose(
            [RandomRotate(10),
             RandomHorizontallyFlip(),
             AddNoise()])  #compose some augmentation functions
    else:
        data_aug = None

    loader = section_loader  #alias the section-based data loader class
    train_set = loader(
        is_transform=True, split='train', augmentations=data_aug
    )  #use custom data loader to get the training set (instance of the loader class)
    val_set = loader(
        is_transform=True,
        split='val')  #use custom made data  loader to get the validation

    n_classes = train_set.n_classes  #initialize the number of classes, which is hard coded in the dataloader

    # Create sampler:

    shuffle = False  # must be False if using a custom sampler (note: the samplers defined below are never actually passed to the DataLoaders in this snippet)
    with open(pjoin('data', 'splits', 'section_train.txt'), 'r') as f:
        train_list = f.read().splitlines(
        )  #load the section train list previously stored in a text file created by split_train_val() function
    with open(pjoin('data', 'splits', 'section_val.txt'), 'r') as f:
        val_list = f.read().splitlines(
        )  #load the section val list previously stored in a text file created by split_train_val() function

    class CustomSamplerTrain(torch.utils.data.Sampler):  #create a custom sampler
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x'
                    ]  #choose randomly between letter i and letter x
            self.indices = [
                idx for (idx, name) in enumerate(train_list) if char[0] in name
            ]  #choose index all inlines or all crosslines from the training list created by split_train_val() function
            return (self.indices[i] for i in torch.randperm(len(self.indices))
                    )  #shuffle the indices and return them

    class CustomSamplerVal(torch.utils.data.Sampler):
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x'
                    ]  #choose randomly between letter i and letter x
            self.indices = [
                idx for (idx, name) in enumerate(val_list) if char[0] in name
            ]  #choose index all inlines or all crosslines from the validation list created by split_train_val() function
            return (self.indices[i] for i in torch.randperm(len(self.indices))
                    )  #shuffle the indices and return them

    trainloader = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=12, shuffle=True
    )  #use pytorch data loader to get batches of the training set (the custom sampler is not passed here)
    valloader = data.DataLoader(
        val_set, batch_size=args.batch_size, num_workers=12
    )  #use pytorch data loader to get the batches of validation set

    # Setup Metrics
    running_metrics = runningScore(
        n_classes
    )  #initialize class instance for evaluation metrics for training
    running_metrics_val = runningScore(
        n_classes
    )  #initialize class instance for evaluation metrics for validation

    # Setup Model
    if args.resume is not None:  #Check if we have a stored model or not
        if os.path.isfile(args.resume):  #if yes then load the stored model
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            print("No checkpoint found at '{}'".format(
                args.resume))  #if stored model requested with invalid path
    else:  #if no stored model then build the requested model
        #n_classes=64
        model = get_model(name=args.arch,
                          pretrained=args.pretrained,
                          batch_size=args.batch_size,
                          growth_rate=32,
                          drop_rate=0,
                          n_classes=n_classes)  #build the requested model from scratch

    model = torch.nn.DataParallel(
        model, device_ids=range(
            torch.cuda.device_count()))  #Use as many GPUs as we can
    model = model.to(device)  # Send to GPU

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            amsgrad=True,
            weight_decay=args.weight_decay,
            eps=args.eps
        )  #if no custom optimizer is specified then use the default Adam optimizer

    loss_fn = core.loss.focal_loss2d  #use the 2D focal loss as the loss function

    if args.class_weights:  #if class weights are to be used then initialize them
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None  #if no class weights then no need to use them

    best_iou = -100.0
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]  #initialize the name of different classes

    for arg in vars(
            args
    ):  #before training starts, write a summary of the parameters
        text = arg + ': ' + str(getattr(
            args, arg))  #get the attribute name and value, make them as string
        writer.add_text('Parameters/', text)  #store the whole string

    # training
    for epoch in range(args.n_epoch):  #for loop on the number of epochs
        # Training Mode:
        model.train()  #initialize training mode
        loss_train, total_iteration = 0, 0  # initialize training loss and total number of iterations

        for i, (images, labels) in enumerate(
                trainloader
        ):  #iterate over the epoch's batches; i is the batch number
            image_original, labels_original = images, labels  #store the image and label batch in new variables
            images, labels = images.to(device), labels.to(
                device)  #move images and labels to the GPU

            optimizer.zero_grad()  #zero the parameter gradients
            outputs = model(
                images
            )  #feed forward the images through the model (outputs is a 7 channel o/p)

            pred = outputs.detach().max(1)[1].cpu().numpy(
            )  #get the model o/p from GPU, select the index of the maximum channel and send it back to CPU
            gt = labels.detach().cpu().numpy(
            )  #get the true labels from GPU and send them to CPU
            running_metrics.update(
                gt, pred
            )  #call the function update and pass the ground truth and the predicted classes

            loss = loss_fn(input=outputs,
                           target=labels,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters
                           )  #call the loss function to calculate the loss
            loss_train += loss.item()  #gets the scalar value held in the loss.
            loss.backward(
            )  # Use autograd to compute the backward pass. This call will compute the gradient of loss with respect to all Tensors with requires_grad=True.

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip
                )  #The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

            optimizer.step(
            )  #step the optimizer (update the model weights with the new gradients)
            total_iteration = total_iteration + 1  #increment the total number of iterations by 1

            if i % 20 == 0:  #every 20 batches, print the current training loss
                print(
                    "Epoch [%d/%d] training Loss: %.4f" %
                    (epoch + 1, args.n_epoch, loss.item())
                )  #print the current epoch, total number of epochs and the current training loss

            numbers = [0, 14, 29, 49, 99]  #batch indices at which to log images
            if i in numbers:  #if the current batch number is in numbers
                # number 0 image in the batch
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True
                )  #select the first image in the batch create a tensorboard grid form the image tensor
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)  #send the image to writer

                labels_original = labels_original.numpy(
                )[0]  #convert the ground truth labels of the first image in the batch to a numpy array
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original)
                )  #Decode segmentation class labels into a color image
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded),
                                 epoch + 1)  #send the image to the writer
                out = F.softmax(outputs, dim=1)  #softmax of the network o/p
                prediction = out.max(1)[1].cpu().numpy()[
                    0]  #get the index of the maximum value after softmax
                confidence = out.max(1)[0].cpu().detach()[
                    0]  # this returns the confidence in the chosen class

                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True
                )  #convert the confidence from tensor to image

                decoded = train_set.decode_segmap(np.squeeze(
                    prediction))  #Decode predicted classes to colours
                writer.add_image(
                    'train/predicted', np_to_tb(decoded), epoch + 1
                )  #send predicted map to writer along with the epoch number
                writer.add_image(
                    'train/confidence', tb_confidence, epoch + 1
                )  #send the confidence to writer along with the epoch number

                unary = outputs.cpu().detach(
                )  #get the network output for the whole batch
                unary_max = torch.max(
                    unary)  #normalize the network output w.r.t. the whole batch
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][
                        channel]  #get the normalized o/p for the first image in the batch
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True,
                        scale_each=True)  #prepare a image from tensor
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel,
                                     epoch + 1)  #send image to writer

        # Average metrics after finishing all batches for the whole epoch, and save in writer()
        loss_train /= total_iteration  #average training loss = total loss over all iterations / number of iterations
        score, class_iou = running_metrics.get_scores(
        )  #returns a dictionary of the calculated accuracy metrics and class iu
        writer.add_scalar(
            'train/Pixel Acc', score['Pixel Acc: '],
            epoch + 1)  # store the epoch metrics in the tensorboard writer
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)
        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()  #resets the confusion matrix
        writer.add_scalar('train/loss', loss_train,
                          epoch + 1)  #store the training loss
        #Finished one epoch of training, starting one epoch of validation
        if args.per_val != 0:  # if validation is required
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()  #start validation mode
                loss_val, total_iteration_val = 0, 0  # initialize validation loss and total number of iterations

                for i_val, (images_val, labels_val) in tqdm(
                        enumerate(valloader)):  #start validation testing
                    image_original, labels_original = images_val, labels_val  #store the original validation images and labels
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(
                            device)  #send validation images and labels to GPU

                    outputs_val = model(images_val)  #feedforward the image
                    pred = outputs_val.detach().max(
                        1)[1].cpu().numpy()  #get the network class prediction
                    gt = labels_val.detach().cpu().numpy(
                    )  #get the ground truth from the GPU

                    running_metrics_val.update(
                        gt, pred)  #run metrics on the validation data

                    loss = loss_fn(input=outputs_val,
                                   target=labels_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters
                                   )  #calculate the loss function
                    total_iteration_val = total_iteration_val + 1  #increment the loop counter

                    if i_val % 20 == 0:  #every 20 validation batches, print the validation loss
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:  #select batch number 0
                        # number 0 image in the batch
                        tb_original_image = vutils.make_grid(
                            image_original[0][0],
                            normalize=True,
                            scale_each=True
                        )  #make first tensor in the batch as image
                        writer.add_image('val/original_image',
                                         tb_original_image,
                                         epoch + 1)  #send image to writer
                        labels_original = labels_original.numpy()[
                            0]  #get the original labels of image 0
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original)
                        )  #convert the labels to colour map
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch +
                                         1)  #send the coloured map to writer

                        out = F.softmax(
                            outputs_val,
                            dim=1)  #get soft max of the network 7 channel o/p

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy(
                        )[0]  #get the position of the max o/p across different channels
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach(
                        )[0]  #get the maximum o/p of the Nw across different channels
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True,
                            scale_each=True)  #convert tensor to image

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction)
                        )  #convert predicted classes to colour maps
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)  #send prediction to writer
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)  #send confidence to writer

                        unary = outputs_val.cpu().detach(
                        )  #get the network output for the current validation batch
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)  #normalize across the whole network output
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(
                                0, len(class_names)
                        ):  #for all the 7 channels of the Nw op
                            tb_channel = vutils.make_grid(
                                unary[0][channel],
                                normalize=True,
                                scale_each=True
                            )  #convert the channel o/p of the class to image
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)  #send image to writer
                # finished one cycle of validation after iterating over all validation batches
                score, class_iou = running_metrics_val.get_scores(
                )  #returns a dictionary of the calculated accuracy metrics and class iu
                for k, v in score.items():  #print all validation metrics
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)  #send metrics to writer
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)
                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss', loss.item(), epoch + 1)
                running_metrics_val.reset()  #reset confusion matrix

                if score['Mean IoU: '] >= best_iou:  #compare the current epoch's validation mean IoU with the best stored value
                    best_iou = score[
                        'Mean IoU: ']  #if better, store it and save the current model as the best model
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                if epoch % 10 == 0:  #every 10 epochs store the current model
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model every 10 epochs:
            if (epoch + 1) % 10 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_ep{epoch + 1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()  #close the writer
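
A quirk worth flagging in the example above: CustomSamplerTrain/CustomSamplerVal are defined (and shuffle = False is set for exactly that purpose), yet the DataLoaders are built with shuffle=True and never receive the samplers. Below is a self-contained sketch of how the sampler could be wired in, reusing the train_set and train_list objects from the example; this is hypothetical wiring, not the original author's code.

import numpy as np
import torch
from torch.utils import data

class InlineCrosslineSampler(data.Sampler):
    """Standalone version of the CustomSamplerTrain idea: each epoch,
    randomly pick either inline ('i') or crossline ('x') sections and
    yield their indices in shuffled order."""
    def __init__(self, names):
        self.names = names

    def __iter__(self):
        char = 'i' if np.random.randint(2) == 1 else 'x'
        indices = [idx for idx, name in enumerate(self.names) if char in name]
        return (indices[i] for i in torch.randperm(len(indices)))

    def __len__(self):
        return len(self.names)  # upper bound; actual length depends on the random choice

trainloader = data.DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=12,
                              sampler=InlineCrosslineSampler(train_list),
                              shuffle=False)  # sampler and shuffle=True are mutually exclusive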
Example #14
cfg = get_cfg(interactive=False)

# prepare dataset
DatasetClass = get_dataset(cfg.DATASET)
dataloader_dict = dict()
for mode in cfg.MODES:
    phase_dataset = DatasetClass(cfg, mode=mode)
    dataloader_dict[mode] = DataLoader(
        phase_dataset,
        batch_size=cfg.BATCHSIZE,
        shuffle=(mode == 'train'),
        num_workers=cfg.DATALOADER_WORKERS,
        pin_memory=True,
        drop_last=True)

# prepare models
ModelClass = get_model(cfg.MODEL)
model = ModelClass(cfg)

# prepare logger
LoggerClass = get_logger(cfg.LOGGER)
logger = LoggerClass(cfg)

# register dataset, models, logger to trainer
trainer = Trainer(cfg, model, dataloader_dict, logger)

# start training
epoch_total = cfg.EPOCH_TOTAL + (cfg.RESUME_EPOCH_ID if cfg.RESUME else 0)
while trainer.do_epoch() <= epoch_total:
    pass
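
The Trainer class itself is not part of this snippet; here is a minimal sketch (assumed, not from the original project) of the do_epoch() contract the while-loop relies on: run one epoch over the registered dataloaders, then return the incremented epoch counter so the loop can terminate.

class Trainer:
    def __init__(self, cfg, model, dataloader_dict, logger):
        self.cfg = cfg
        self.model = model
        self.loaders = dataloader_dict
        self.logger = logger
        self.epoch = cfg.RESUME_EPOCH_ID if cfg.RESUME else 0

    def do_epoch(self):
        for mode in self.cfg.MODES:  # e.g. ['train', 'val']
            for batch in self.loaders[mode]:
                pass  # forward/backward/logging elided
        self.epoch += 1
        return self.epoch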
Example #15
File: train.py Project: lilujunai/MTLNAS
def main():
    parser = argparse.ArgumentParser(description="Baseline Experiment Training")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d~%H:%M:%S")
    experiment_log_dir = os.path.join(cfg.LOG_DIR, cfg.EXPERIMENT_NAME, timestamp)
    if not os.path.exists(experiment_log_dir):
        os.makedirs(experiment_log_dir)
    writer = SummaryWriter(logdir=experiment_log_dir)
    printf = get_print(experiment_log_dir)
    printf("Training with Config: ")
    printf(cfg)

    # Seeding
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    if not os.path.exists(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME)):
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    train_data = get_dataset(cfg, 'train')
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, pin_memory=True)

    # load the data
    if cfg.TRAIN.EVAL_CKPT:
        test_loader = torch.utils.data.DataLoader(
            get_dataset(cfg, 'val'),
            batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, pin_memory=True)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)
    
    if cfg.CUDA:
        model = model.cuda()

    # hacky way to pick params
    nddr_params = []
    fc8_weights = []
    fc8_bias = []
    base_params = []
    for k, v in model.named_parameters():
        if 'nddrs' in k:
            nddr_params.append(v)
        elif model.net1.fc_id in k:
            if 'weight' in k:
                fc8_weights.append(v)
            else:
                assert 'bias' in k
                fc8_bias.append(v)
        else:
            base_params.append(v)
    
    if not cfg.MODEL.SINGLETASK and not cfg.MODEL.SHAREDFEATURE:
        assert len(nddr_params) > 0 and len(fc8_weights) > 0 and len(fc8_bias) > 0

    parameter_dict = [
        {'params': fc8_weights, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_WEIGHT_FACTOR},
        {'params': fc8_bias, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_BIAS_FACTOR},
        {'params': nddr_params, 'lr': cfg.TRAIN.LR * cfg.TRAIN.NDDR_FACTOR}
    ]
    
    if not cfg.TRAIN.FREEZE_BASE:
        parameter_dict.append({'params': base_params})
    else:
        printf("Frozen net weights")
        
    optimizer = optim.SGD(parameter_dict, lr=cfg.TRAIN.LR, momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    if cfg.TRAIN.SCHEDULE == 'Poly':
        if cfg.TRAIN.WARMUP > 0.:
            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lambda step: min(1., float(step) / cfg.TRAIN.WARMUP) * (1 - float(step) / cfg.TRAIN.STEPS) ** cfg.TRAIN.POWER,
                                                    last_epoch=-1)
        else:
            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lambda step: (1 - float(step) / cfg.TRAIN.STEPS) ** cfg.TRAIN.POWER,
                                                    last_epoch=-1)
    elif cfg.TRAIN.SCHEDULE == 'Cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.TRAIN.STEPS)
    else:
        raise NotImplementedError
        
    if cfg.TRAIN.APEX:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    model.train()
    steps = 0
    while steps < cfg.TRAIN.STEPS:
        for batch_idx, (image, label_1, label_2) in enumerate(train_loader):
            if cfg.CUDA:
                image, label_1, label_2 = image.cuda(), label_1.cuda(), label_2.cuda()
            optimizer.zero_grad()

            result = model.loss(image, (label_1, label_2))
            out1, out2 = result.out1, result.out2

            loss1 = result.loss1
            loss2 = result.loss2

            loss = result.loss
            
            # Mixed Precision
            if cfg.TRAIN.APEX:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            model.step()  # update model step count
            scheduler.step()

            # Print out the loss periodically.
            if steps % cfg.TRAIN.LOG_INTERVAL == 0:
                printf('Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLoss1: {:.6f}\tLoss2: {:.6f}'.format(
                    steps, batch_idx * len(image), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.data.item(),
                    loss1.data.item(), loss2.data.item()))

                # Log to tensorboard
                writer.add_scalar('lr', scheduler.get_lr()[0], steps)
                writer.add_scalar('loss/overall', loss.data.item(), steps)
                task1.log_visualize(out1, label_1, loss1, writer, steps)
                task2.log_visualize(out2, label_2, loss2, writer, steps)
                writer.add_image('image', process_image(image[0], train_data.image_mean), steps)

            if steps % cfg.TRAIN.SAVE_INTERVAL == 0:
                checkpoint = {
                    'cfg': cfg,
                    'step': steps,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss,
                    'loss1': loss1,
                    'loss2': loss2,
                    'task1_metric': None,
                    'task2_metric': None,
                }

                if cfg.TRAIN.EVAL_CKPT:
                    model.eval()
                    task1_metric, task2_metric = evaluate(test_loader, model, task1, task2)
                    for k, v in task1_metric.items():
                        writer.add_scalar('eval/{}'.format(k), v, steps)
                    for k, v in task2_metric.items():
                        writer.add_scalar('eval/{}'.format(k), v, steps)
                    for k, v in task1_metric.items():
                        printf('{}: {:.3f}'.format(k, v))
                    for k, v in task2_metric.items():
                        printf('{}: {:.3f}'.format(k, v))

                    checkpoint['task1_metric'] = task1_metric
                    checkpoint['task2_metric'] = task2_metric
                    model.train()

                torch.save(checkpoint, os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                                                    'ckpt-%s.pth' % str(steps).zfill(5)))

            if steps >= cfg.TRAIN.STEPS:
                break
            steps += 1
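
For reference, the Poly-with-warmup multiplier passed to LambdaLR above can be written as a plain function; this is a sketch for inspection or plotting, not part of the original file.

def poly_warmup_multiplier(step, warmup, total_steps, power):
    warm = min(1.0, float(step) / warmup) if warmup > 0 else 1.0
    return warm * (1.0 - float(step) / total_steps) ** power

# e.g. warmup=1000, total_steps=20000, power=0.9:
# step=0 -> 0.0, step=1000 -> ~0.955, step=20000 -> 0.0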
Example #16
def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Generate the train and validation sets for the model:
    split_train_val_weak(args, per_val=args.per_val)
    loader = patch_loader_weak

    current_time = datetime.now().strftime('%b%d_%H%M%S')
    log_dir = os.path.join('runs',
                           current_time + f"_{args.arch}_{args.model_name}")
    writer = SummaryWriter(log_dir=log_dir)
    # Setup Augmentations
    if args.aug:
        data_aug = Compose(
            [RandomRotate(15),
             RandomHorizontallyFlip(),
             AddNoise()])
    else:
        data_aug = None

    train_set = loader(is_transform=True,
                       split='train',
                       augmentations=data_aug)

    # Without Augmentation:
    val_set = loader(is_transform=True,
                     split='val',
                     patch_size=args.patch_size)

    #if args.mixup:
    #    train_set1 = loader(is_transform=True,
    #                       split='train',
    #                       augmentations=data_aug)

    n_classes = train_set.n_classes

    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    #####################################################################
    #shuffle and load
    random.shuffle(train_set.patches['train'])  #shuffle list of IDs
    alpha = 0.5
    trainloader1 = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=4,
        shuffle=True)  #load the shuffled data again in another loader
    ######################################################################

    valloader = data.DataLoader(val_set,
                                batch_size=args.batch_size,
                                num_workers=4)

    # Setup Metrics
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(args.arch, args.pretrained, n_classes)

    # Use as many GPUs as we can
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.to(device)  # Send to GPU

    # PYTORCH NOTE: ALWAYS CONSTRUCT OPTIMIZERS AFTER THE MODEL IS PUSHED TO GPU/CPU

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(model.parameters(), amsgrad=True)

    loss_fn = core.loss.focal_loss2d

    if args.class_weights:
        # weights are inversely proportional to the frequency of the classes in the training set
        class_weights = torch.tensor(
            [0, 0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]

    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # training
    for epoch in range(args.n_epoch):
        # Training Mode:
        model.train()
        loss_train, total_iteration = 0, 0

        for (i, (images, labels, confs,
                 sims)), (i1, (images1, labels1, confs1,
                               sims1)) in zip(enumerate(trainloader),
                                              enumerate(trainloader1)):

            N, c, w, h = labels.shape
            one_hot = torch.FloatTensor(N, 7, w, h).zero_()
            labels_hot = one_hot.scatter_(
                1, labels.data,
                1)  # create one hot representation for the labels

            if args.mixup:  #if mixup is enabled then mix the two streams
                lam = torch.from_numpy(
                    np.random.beta(alpha, alpha,
                                   (N, 1, 1, 1))).float()  #sample lambda ~ Beta(alpha, alpha) per example
                one_hot = torch.FloatTensor(N, 7, w, h).zero_()
                labels_hot1 = one_hot.scatter_(
                    1, labels1.data,
                    1)  # create one hot representation for the labels
                # mixup: convex combinations of both streams
                images = lam * images + (1 - lam) * images1
                labels = lam * labels.float() + (1 - lam) * labels1.float()
                labels_hot = lam * labels_hot + (1 - lam) * labels_hot1
                confs = lam * confs.squeeze() + (1 - lam) * confs1.squeeze()
                sims = lam.squeeze() * sims.float() + (1 - lam).squeeze() * sims1.float()

            image_original = images  #TODO Q: are the passed original labels correct? (in the context of the following comparison in line 233)
            images, labels_hot, confs, sims = images.to(device), labels_hot.to(
                device), confs.to(device), sims.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            pred = outputs.detach().max(1)[1].cpu().numpy()
            labels_original = confs.squeeze().permute(
                0, 3, 1, 2).detach().max(1)[1].cpu().numpy()
            running_metrics.update(labels_original, pred)
            loss = loss_fn(input=outputs,
                           target=labels_hot,
                           conf=confs,
                           alpha=class_weights,
                           sim=sims,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters,
                           soft_dev=args.soft_dev)
            loss_train += loss.item()
            loss.backward()

            # gradient clipping
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0, 14, 29]
            if i in numbers:

                tb_original_image = vutils.make_grid(image_original[0][0],
                                                     normalize=True,
                                                     scale_each=True)
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)

                # tb_confs_original = vutils.make_grid(confs_tb, normalize=True, scale_each=True)
                # writer.add_image('train/confs_original',tb_confs_original, epoch +1)

                labels_original = labels_original[0]
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded), epoch + 1)
                out = F.softmax(outputs, dim=1)

                # this returns the max. channel number:
                prediction = out.max(1)[1].cpu().numpy()[0]
                # this returns the confidence:
                confidence = out.max(1)[0].cpu().detach()[0]
                tb_confidence = vutils.make_grid(confidence,
                                                 normalize=True,
                                                 scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded),
                                 epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(decoded_channel,
                                                  normalize=True,
                                                  scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel, epoch + 1)

        # Average metrics, and save in writer()
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch + 1)
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)

        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()
        writer.add_scalar('train/loss', loss_train, epoch + 1)

        if args.per_val != 0:
            with torch.no_grad():  # operations inside don't track history
                # Validation Mode:
                model.eval()
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val, conf_val,
                            sim_val) in tqdm(enumerate(valloader)):

                    N, c, w, h = labels_val.shape
                    one_hot = torch.FloatTensor(N, 7, w, h).zero_()
                    labels_hot_val = one_hot.scatter_(
                        1, labels_val.data,
                        1)  # create one hot representation for the labels

                    image_original, labels_original = images_val, labels_val
                    images_val, labels_hot_val, conf_val, sim_val = images_val.to(
                        device), labels_hot_val.to(device), conf_val.to(
                            device), sim_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val,
                                   target=labels_hot_val,
                                   conf=conf_val,
                                   alpha=class_weights,
                                   sim=sim_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters,
                                   soft_dev=args.soft_dev)

                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:

                        tb_original_image = vutils.make_grid(
                            image_original[i_val][0],
                            normalize=True,
                            scale_each=True)
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # this returns the max. channel number:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # this returns the confidence:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(confidence,
                                                         normalize=True,
                                                         scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)

                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(unary[0][channel],
                                                          normalize=True,
                                                          scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)

                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss', loss.item(), epoch + 1)
                running_metrics_val.reset()

                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                if epoch % 10 == 0:
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation is turned off:
            # just save the latest model:
            if epoch % 10 == 0:
                model_dir = os.path.join(log_dir,
                                         f"{args.arch}_ep{epoch+1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()
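
The mixup block in this example packs one-hot encoding and the convex combination into a single step; the following self-contained sketch restates the same two operations as small helpers on toy shapes (hypothetical names, not from the original code).

import numpy as np
import torch

def one_hot_2d(labels, n_classes):
    """labels: Nx1xWxH integer class ids -> NxKxWxH one-hot maps
    (the scatter_ trick used above)."""
    n, _, w, h = labels.shape
    return torch.zeros(n, n_classes, w, h).scatter_(1, labels, 1)

def mixup_pair(x, y_hot, x1, y_hot1, alpha=0.5):
    """Per-example lam ~ Beta(alpha, alpha); convex combination of two batches."""
    lam = torch.from_numpy(
        np.random.beta(alpha, alpha, (x.size(0), 1, 1, 1))).float()
    return lam * x + (1 - lam) * x1, lam * y_hot + (1 - lam) * y_hot1

# toy usage
labels = torch.randint(0, 7, (2, 1, 4, 4))
y = one_hot_2d(labels, 7)
assert y.sum(dim=1).eq(1).all()  # exactly one class per pixel
x_mix, y_mix = mixup_pair(torch.randn(2, 1, 4, 4), y,
                          torch.randn(2, 1, 4, 4), y.clone())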