Пример #1
0
    def test_padding_none(self):
        """A corrupted (None) sample is dropped while padding variable-length tensors."""
        samples = []
        for idx in range(DATASET_SIZE):
            samples.append(torch.randn(2 * idx + 1, 3))
        samples[5] = None  # corrupt exactly one sample
        safe_ds = nc.SafeDataset(DictDataset(samples))

        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds, batch_size=bs,
                                       collate_fn=self.pad)
            remaining = DATASET_SIZE - 1  # the None sample never appears
            for batch in loader:
                # the padding collate places the batch on dim 1
                expected = min(bs, remaining)
                self.assertEqual(batch.size(1), expected)
                remaining -= batch.size(1)
Пример #2
0
    def test_alright(self):
        """With no corrupted samples the safe loader behaves like a plain loader."""
        data = [torch.randn(2, 3) for _ in range(DATASET_SIZE)]
        safe_ds = nc.SafeDataset(TensorDataset(data))

        for bs in range(1, DATASET_SIZE):
            safe_loader = nc.SafeDataLoader(safe_ds, batch_size=bs)

            left = DATASET_SIZE
            for batch in safe_loader:
                # every batch is full-sized except possibly the last
                self.assertEqual(batch.size(0), min(bs, left))
                left -= batch.size(0)
Пример #3
0
    def test_padding_alright(self):
        """All samples valid: pad_sequence collates different-length tensors."""
        data = []
        for idx in range(DATASET_SIZE):
            data.append(torch.randn(2 * idx + 1, 3))
        safe_ds = nc.SafeDataset(TensorDataset(data))

        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds,
                                       batch_size=bs,
                                       collate_fn=pad_sequence)

            remaining = DATASET_SIZE
            for batch in loader:
                # pad_sequence defaults to batch_first=False -> batch on dim 1
                self.assertEqual(batch.size(1), min(bs, remaining))
                remaining -= batch.size(1)
Пример #4
0
    def test_custom_collate_alright(self):
        """torch.stack works as a custom collate_fn on a fully valid dataset."""
        data = [torch.randn(2, 3) for _ in range(DATASET_SIZE)]
        safe_ds = nc.SafeDataset(TensorDataset(data))

        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds,
                                       batch_size=bs,
                                       collate_fn=torch.stack)

            left = DATASET_SIZE
            for batch in loader:
                expected = min(bs, left)
                self.assertEqual(batch.size(0), expected)
                left -= batch.size(0)
Пример #5
0
    def test_none(self):
        """A single None dict-sample is skipped; both batch fields stay consistent."""
        samples = [torch.randn(2, 3) for _ in range(DATASET_SIZE)]
        samples[5] = None  # corrupt one sample
        safe_ds = nc.SafeDataset(DictDataset(samples))

        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds, batch_size=bs)

            remaining = DATASET_SIZE - 1  # the None sample is dropped
            for batch in loader:
                expected = min(bs, remaining)
                for key in ('idx', 'tensor'):
                    self.assertEqual(batch[key].size(0), expected)
                remaining -= batch['idx'].size(0)
Пример #6
0
    def test_alright(self):
        """Dict-valued samples are batched normally when nothing is corrupted."""
        samples = [torch.randn(2, 3) for _ in range(DATASET_SIZE)]
        safe_ds = nc.SafeDataset(DictDataset(samples))

        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds, batch_size=bs)

            remaining = DATASET_SIZE
            for batch in loader:
                expected = min(bs, remaining)
                # both keys of the collated dict must agree on batch size
                for key in ('idx', 'tensor'):
                    self.assertEqual(batch[key].size(0), expected)
                remaining -= batch['idx'].size(0)
Пример #7
0
    def test_padding_none_batch_first(self):
        """Like the padding+None case, but collating with batch_first=True."""
        samples = [torch.randn(2 * k + 1, 3) for k in range(DATASET_SIZE)]
        samples[5] = None
        safe_ds = nc.SafeDataset(TensorDataset(samples))

        collate = partial(pad_sequence, batch_first=True)
        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds,
                                       batch_size=bs,
                                       collate_fn=collate)
            remaining = DATASET_SIZE - 1  # one sample is None
            for batch in loader:
                # batch_first=True -> batch dimension is dim 0
                self.assertEqual(batch.size(0), min(bs, remaining))
                remaining -= batch.size(0)
Пример #8
0
    def test_custom_collate_none(self):
        """torch.stack as collate_fn still works when one sample is corrupted."""
        data = [torch.randn(2, 3) for _ in range(DATASET_SIZE)]
        data[5] = None
        safe_ds = nc.SafeDataset(TensorDataset(data))

        for bs in range(1, DATASET_SIZE):
            loader = nc.SafeDataLoader(safe_ds,
                                       batch_size=bs,
                                       collate_fn=torch.stack)

            left = DATASET_SIZE - 1  # the corrupted sample is dropped
            for batch in loader:
                expected = min(bs, left)
                self.assertEqual(batch.size(0), expected)
                left -= batch.size(0)
Пример #9
0
    def test_none(self):
        """`None` entries anywhere in the dataset are silently skipped."""

        for corrupt_count in range(1, DATASET_SIZE):
            samples = [torch.randn(2, 3) for _ in range(DATASET_SIZE)]
            # corrupt a random subset of positions
            for pos in random.sample(range(DATASET_SIZE), corrupt_count):
                samples[pos] = None

            safe_ds = nc.SafeDataset(TensorDataset(samples))

            for bs in range(1, DATASET_SIZE):
                loader = nc.SafeDataLoader(safe_ds, batch_size=bs)

                remaining = DATASET_SIZE - corrupt_count
                for batch in loader:
                    self.assertEqual(batch.size(0), min(bs, remaining))
                    remaining -= batch.size(0)
Пример #10
0
def main():
    """Run stratified K-fold training of `config.arch` on the Freesound data.

    Uses module-level globals: `config`, `DEBUG`, and project helpers
    (Freesound_logmel, ToTensor, run_method_by_string, train_on_fold,
    cross_entropy_onehot).
    """

    train = pd.read_csv('../input/train.csv')

    # Map each label string to a stable integer id.
    LABELS = list(train.label.unique())
    label_idx = {label: i for i, label in enumerate(LABELS)}
    # NOTE: the original called `train.set_index("fname")` here and discarded
    # the result; without `inplace=True` DataFrame.set_index mutates nothing,
    # so the call was a no-op and has been removed.  The code below relies on
    # the default RangeIndex (reset_index(drop=True)), so actually setting
    # the index would change behavior.

    train["label_idx"] = train.label.apply(lambda x: label_idx[x])

    # Use a small subset for quick debugging runs.
    if DEBUG:
        train = train[:500]

    skf = StratifiedKFold(n_splits=config.n_folds)

    for foldNum, (train_split,
                  val_split) in enumerate(skf.split(train, train.label_idx)):

        fold_start = time.time()  # wall-clock start of this fold
        # split the dataset for cross-validation
        train_set = train.iloc[train_split].reset_index(drop=True)
        val_set = train.iloc[val_split].reset_index(drop=True)
        logging.info("Fold {0}, Train samples:{1}, val samples:{2}".format(
            foldNum, len(train_set), len(val_set)))

        # Train/val loaders; SafeDataset/SafeDataLoader drop samples that raise.
        trainSet = Freesound_logmel(config=config,
                                    frame=train_set,
                                    transform=transforms.Compose([ToTensor()]),
                                    mode="train")
        train_loader = nc.SafeDataLoader(nc.SafeDataset(trainSet),
                                         batch_size=config.batch_size,
                                         shuffle=True,
                                         num_workers=0,
                                         pin_memory=True)
        valSet = Freesound_logmel(config=config,
                                  frame=val_set,
                                  transform=transforms.Compose([ToTensor()]),
                                  mode="train")
        val_loader = nc.SafeDataLoader(nc.SafeDataset(valSet),
                                       batch_size=config.batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       pin_memory=True)

        model = run_method_by_string(config.arch)(pretrained=config.pretrain)

        # define loss function (criterion) and optimizer
        if config.mixup:
            # mixup produces soft (mixed one-hot) targets
            train_criterion = cross_entropy_onehot
        else:
            train_criterion = nn.CrossEntropyLoss().cuda()

        val_criterion = nn.CrossEntropyLoss().cuda()
        optimizer = optim.SGD(model.parameters(),
                              lr=config.lr,
                              momentum=config.momentum,
                              weight_decay=config.weight_decay)

        cudnn.benchmark = True  # let cudnn autotune for fixed input sizes

        train_on_fold(model, train_criterion, val_criterion, optimizer,
                      train_loader, val_loader, config, foldNum)

        time_on_fold = time.strftime('%Hh:%Mm:%Ss',
                                     time.gmtime(time.time() - fold_start))
        logging.info(
            "--------------Time on fold {}: {}--------------\n".format(
                foldNum, time_on_fold))
Пример #11
0
    # Pick the newest discriminator checkpoint matching D_MODEL_PATH;
    # if none exists, fall back to training the discriminator from scratch.
    if USE_ADVERSARIAL:
        d_model_path = sorted(glob(D_MODEL_PATH, recursive=True))
        if len(d_model_path) == 0:
            LOAD_D = False
        else:
            # lexicographically last path -- presumably the latest
            # checkpoint; TODO confirm the filename scheme sorts that way
            d_model_path = d_model_path[-1]

    os.makedirs(output_dir, exist_ok=True)

    # SafeDataset drops samples that raise during loading instead of crashing.
    dataset = nc.SafeDataset(
        MidiDataset(npy_glob_pattern=DATASET_PATH,
                    num_samples=num_samples,
                    train_step_multiplier=100))
    dataloader = nc.SafeDataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4)

    writer = SummaryWriter('logs')  # TensorBoard event log directory

    # Generator is a VAE; the discriminator is built only for adversarial runs.
    G = M.VAE(num_samples, height, width, h_dim, z_dim)

    if USE_ADVERSARIAL:
        D = M.Discriminator(num_samples, height, width)

    # Optionally resume the generator from a checkpoint.
    if LOAD_G:
        G.load_state_dict(torch.load(g_model_path))
        print(f"Loaded model {g_model_path}, epoch={start_epoch}")
    else:
        print("Training generator(VAE) from scratch")
def train(content_set,
          style_set,
          batch_size=1,
          learning_rate=0.005,
          num_epochs=1000):
    """Gatys-style optimization: optimizes a copy of one content image
    toward the style statistics drawn from `style_set`.

    Args:
        content_set: dataset; element 1600's first item is used as content.
        style_set: dataset of style images (wrapped in nc.SafeDataset).
        batch_size: style-loader batch size.
        learning_rate: Adam learning rate for the optimized image.
        num_epochs: number of passes over the style loader.
    """
    # (removed: unused local `iteration = 20000`)
    # Per-layer style-loss weights.
    betas = {
        'conv1_1': 1.,
        'conv2_1': 0.8,
        'conv3_1': 0.3,
        'conv4_1': 0.25,
        'conv5_1': 0.2
    }

    # Content vs. style trade-off.
    alpha = 1
    beta = 1500000

    np.random.seed(1000)

    # Fixed content image; the optimized variable is a clone of it.
    content_img = content_set[1600][0].unsqueeze(0).to("cpu").detach()
    transformed_img = content_img.clone().requires_grad_(True)

    style_set = nc.SafeDataset(style_set)  # skip unreadable style samples
    style_loader = nc.SafeDataLoader(style_set,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1)

    optimizer = optim.Adam([transformed_img], lr=learning_rate)
    train_loss = []
    for i in range(num_epochs):
        content_batch = content_img
        content_rep = get_content_rep(content_batch)
        # NOTE(review): `loss` below is only bound if style_loader yields at
        # least one batch.
        for style_steps, (style_imgs, __) in enumerate(style_loader):
            target_reps = get_target_rep(transformed_img)
            s_loss = 0
            # Content loss is measured at conv4_2 only.
            c_loss = get_content_loss(content_rep['conv4_2'],
                                      target_reps['conv4_2'])
            style_rep = get_style_rep(style_imgs)
            style_gs = {
                layer: gram_matrix(style_rep[layer])
                for layer in style_rep
            }
            # Weighted sum of Gram-matrix losses across the style layers.
            for layer in betas:
                target_rep = target_reps[layer]
                target_g = gram_matrix(target_rep)
                style_g = style_gs[layer]
                s_loss_temp = betas[layer] * get_style_loss(style_g, target_g)
                s_loss += s_loss_temp

            loss = alpha * c_loss + beta * s_loss
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
        if i % 50 == 0:
            print('EPOCH:{} current loss:{} '.format(i, loss.item()))
            # Un-normalize (ImageNet mean/std) and display the content image...
            img = content_img.to("cpu").detach()
            img = img[0].numpy()
            img = img.transpose(1, 2, 0)
            img = img * np.array((0.229, 0.224, 0.225)) + np.array(
                (0.485, 0.456, 0.406))
            img = img.clip(0, 1)
            plt.imshow(img)
            plt.show()
            # ...and the current optimized image.
            img = transformed_img.to("cpu").detach()
            img = img[0].numpy()
            img = img.transpose(1, 2, 0)
            img = img * np.array((0.229, 0.224, 0.225)) + np.array(
                (0.485, 0.456, 0.406))
            img = img.clip(0, 1)
            plt.imshow(img)
            plt.show()

        if i % 100 == 0:
            train_loss.append(loss.item())

    n = len(train_loss)
    plt.title("Train Loss")
    plt.plot(range(1, n + 1), train_loss, label="Train")
    plt.xlabel("Iteration")  # fixed typo: was "Interation"
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.show()
def train(transform_model,
          content_set,
          train_indices,
          val_indices,
          style_set,
          batch_size=1,
          learning_rate=0.005,
          num_epochs=1000,
          save_name='model_graphite'):
    """Train a feed-forward style-transfer network on content/style datasets.

    Args:
        transform_model: image-to-image network being trained.
        content_set: dataset of content images.
        train_indices: indices of `content_set` used for training.
        val_indices: indices of `content_set` used for validation.
        style_set: dataset of style images (wrapped in nc.SafeDataset).
        batch_size, learning_rate, num_epochs: training hyper-parameters.
        save_name: checkpoint filename (saved to a hard-coded Drive path).
    """
    # (removed: unused local `iteration = 20000`)
    # Per-layer style-loss weights.
    betas = {
        'conv1_1': 1.,
        'conv2_1': 0.8,
        'conv3_1': 0.3,
        'conv4_1': 0.25,
        'conv5_1': 0.2
    }

    # Content vs. style trade-off.
    alpha = 1
    beta = 1500000

    np.random.seed(1000)

    train_sampler = SubsetRandomSampler(train_indices)
    content_loader = torch.utils.data.DataLoader(content_set,
                                                 batch_size=batch_size,
                                                 num_workers=1,
                                                 sampler=train_sampler)

    # NOTE(review): the validation loader is built but never used below.
    val_sampler = SubsetRandomSampler(val_indices)
    content_val_loader = torch.utils.data.DataLoader(content_set,
                                                     batch_size=batch_size,
                                                     num_workers=1,
                                                     sampler=val_sampler)

    style_set = nc.SafeDataset(style_set)  # skip unreadable style samples
    style_loader = nc.SafeDataLoader(style_set,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1)

    optimizer = optim.Adam(transform_model.parameters(), lr=learning_rate)
    transform_model.train()
    train_loss = []
    for i in range(num_epochs):
        for j, (content_img, __) in enumerate(content_loader):
            content_batch = content_img
            target_batch = content_batch.clone().requires_grad_(True)
            content_rep = get_content_rep(content_batch)
            for style_steps, (style_imgs, __) in enumerate(style_loader):
                transformed_img = transform_model(target_batch.cuda())
                target_reps = get_target_rep(transformed_img)
                s_loss = 0
                # Content loss is measured at conv4_2 only.
                c_loss = get_content_loss(content_rep['conv4_2'],
                                          target_reps['conv4_2'])
                style_rep = get_style_rep(style_imgs)
                style_gs = {
                    layer: gram_matrix(style_rep[layer])
                    for layer in style_rep
                }
                # Weighted sum of Gram-matrix losses over the style layers.
                for layer in betas:
                    target_rep = target_reps[layer]
                    target_g = gram_matrix(target_rep)
                    style_g = style_gs[layer]
                    s_loss_temp = betas[layer] * get_style_loss(
                        style_g, target_g)
                    s_loss += s_loss_temp

                loss = alpha * c_loss + beta * s_loss
                loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.zero_grad()
            if j % 50 == 0:
                print('EPOCH:{} current loss:{} '.format(i, loss.item()))
                # Checkpoint the model every 50 content batches.
                torch.save(
                    transform_model.state_dict(),
                    '/content/gdrive/My Drive/APS360_Style_Transfer/Saved Models/'
                    + save_name)
                # Un-normalize (ImageNet mean/std) and show the content image...
                img = content_img.to("cpu").detach()
                img = img[0].numpy()
                img = img.transpose(1, 2, 0)
                img = img * np.array((0.229, 0.224, 0.225)) + np.array(
                    (0.485, 0.456, 0.406))
                img = img.clip(0, 1)
                plt.imshow(img)
                plt.show()
                # ...and the current stylized output.
                img = transformed_img.to("cpu").detach()
                img = img[0].numpy()
                img = img.transpose(1, 2, 0)
                img = img * np.array((0.229, 0.224, 0.225)) + np.array(
                    (0.485, 0.456, 0.406))
                img = img.clip(0, 1)
                plt.imshow(img)
                plt.show()

            if j % 100 == 0:
                train_loss.append(loss.item())

    n = len(train_loss)
    plt.title("Train Loss")
    plt.plot(range(1, n + 1), train_loss, label="Train")
    plt.xlabel("Iteration")  # fixed typo: was "Interation"
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.show()
Пример #14
0
def main():
    """Train a DQN-based extractive summarizer with experience replay.

    Relies on project helpers defined/imported elsewhere in this file:
    Configuration, load_target_vocab, TextDataset, DQN, Reward,
    ReplayMemory, Sum_i, check_done, optimize_model.
    """

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')

    configuration = Configuration('project.config')
    config = configuration.get_config_options()

    # create_vocabulary_from_dataset(config)
    # make_target_vocab(config)

    word2idx, dataset_vectors = load_target_vocab(config)

    use_safe_dataset = True
    dataset = TextDataset(word2idx, dataset_vectors, config=config)
    if use_safe_dataset:
        # SafeDataset/SafeDataLoader transparently drop samples that raise.
        dataset = nc.SafeDataset(dataset)
        data_loader = nc.SafeDataLoader(dataset=dataset,
                                        batch_size=config.globals.BATCH_SIZE,
                                        num_workers=0,
                                        shuffle=True)
    else:
        data_loader = datautil.DataLoader(dataset=dataset,
                                          batch_size=config.globals.BATCH_SIZE,
                                          num_workers=0,
                                          shuffle=False)

    # model = SentenceEncoder(target_vocab = word2idx.keys(), vectors = dataset_vectors, config = config)
    # doc_enc = DocumentEncoder(config=config)

    policy_net = DQN(target_vocab=word2idx.keys(),
                     vectors=dataset_vectors,
                     config=config).to(device)
    target_net = DQN(target_vocab=word2idx.keys(),
                     vectors=dataset_vectors,
                     config=config).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()  # target network is never trained directly

    reward_func = Reward()
    # Smoke-test the reward function on a trivial example.
    h = reward_func.get_reward([['hello']], [[['this is a good hello']]])
    print(h)

    def select_action(config, doc, state):
        """Epsilon-greedy sentence pick; returns ((index, sentence), q_values)."""
        # TODO: fix the function to handle the full batchsize
        # TODO: send all tensors to GPU

        sample = np.random.random()

        # article = '\n'.join(doc['raw'])
        # article = article.split('\n\n')
        doc_tensor = doc['tensor'][:, :len(doc['raw']) - 1]
        # Putting this here as we need q_values one way or the other
        q_values = policy_net(doc_tensor,
                              get_q_approx=True,
                              sum_i=state['sum_i'])

        # Decay epsilon every EPS_DECAY_ITER steps.  `global_step` is a
        # closure over the counter maintained in the training loop below
        # (renamed from `iter`, which shadowed the builtin).
        if global_step % config.dqn.EPS_DECAY_ITER == 0:
            config.dqn.EPS_START -= config.dqn.EPS_DECAY
            print('EPSILON Decayed to : ', config.dqn.EPS_START)

        if sample < config.dqn.EPS_START:
            # Explore: pick a random sentence index.
            i = np.random.randint(low=0, high=len(doc['raw']) - 1)
        else:
            # Exploit: actions are sentences; take the arg-max Q sentence.
            i = torch.argmax(q_values, dim=1)

        a_i = (i, doc['raw'][i])
        return a_i, q_values

    optimizer = torch.optim.RMSprop(policy_net.parameters())
    memory = ReplayMemory(config.dqn.REPLAY_MEM_SIZE)

    epoch = 0
    global_step = 0

    for epoch in tqdm(range(epoch, config.globals.NUM_EPOCHS)):
        policy_net.train()
        for i, (story, highlights) in tqdm(enumerate(data_loader)):
            state = {
                'curr_summary_ids': [],
                'curr_summary': [],
                'sum_i': torch.zeros((100))
            }
            next_state = state
            # BUG FIX: was `prev_r_i, r_i = 0`, which raises TypeError
            # (cannot unpack an int); both rewards start at 0.
            prev_r_i = r_i = 0
            # locking to 10 for simplicity purposes
            # NOTE(review): itertools.count(start) is an infinite counter
            # beginning at SUMMARY_LENGTH; this loop exits only via `done`
            # -- confirm the start value is intentional.
            for step in count(config.dqn.SUMMARY_LENGTH):
                global_step += 1

                # if step > 20: break

                story['tensor'] = story['tensor'].to(device)
                highlights['tensor'] = highlights['tensor'].to(device)
                # sentence representation are calculated as no grad because we dont want to disturb / update the weights
                with torch.no_grad():
                    H_i, D_i, x = policy_net(
                        story['tensor'][:, :len(story['raw']) - 1])

                a_i, q_values = select_action(config, story, state)

                next_state['curr_summary_ids'].append(int(a_i[0]))
                next_state['curr_summary'].append(a_i[1])
                next_state['sum_i'] = Sum_i(H_i, state['curr_summary_ids'],
                                            q_values)
                r_i = reward_func.get_reward([next_state['curr_summary']],
                                             gold_summ=[[highlights['raw']]],
                                             **{
                                                 'prev_score': prev_r_i,
                                                 'config': config
                                             })
                prev_r_i = r_i
                # checks if we are close to the summ length part
                done = check_done(config, next_state)
                if done:
                    next_state = None
                # TODO: check which a_i has to be loaded , a_i[0] or a_i[1] or just a_i
                memory.push(state, H_i[a_i[0]], next_state, r_i)
                state = next_state
                optimize_model(config)

                if done:
                    break
Пример #15
0
from torchvision import transforms
from torch.utils.data import DataLoader as dataloader
import nonechucks as nc
from voc_seg import my_data, label_acc_score, voc_colormap, seg_target
vgg = tv.models.vgg19_bn(pretrained=True)  # pretrained backbone (`tv` imported elsewhere)

# ImageNet normalization applied to the input images.
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
train_data = my_data((240, 320), transform=image_transform)
test_data = my_data((240, 320), image_set='val', transform=image_transform)
# Wrap both splits so samples that raise are skipped instead of crashing.
trainset = nc.SafeDataset(train_data)
testset = nc.SafeDataset(test_data)
# trainload=nc.SafeDataLoader(trainset,batch_size=8)
testload = nc.SafeDataLoader(testset, batch_size=8)

mask_transform = transforms.Compose(
    [seg_target()])  # to_tensor will make it from nhwc to nchw

# NOTE(review): `trainset` above is never fed to a loader (its
# SafeDataLoader is commented out); training reads from `trainload`
# built on the raw VOC dataset below -- confirm this is intended.
train_voc = tv.datasets.VOCSegmentation(
    '/home/llm/PycharmProjects/seg_1224/data/',
    image_set='train',
    transform=image_transform,
    target_transform=mask_transform)
trainload = torch.utils.data.DataLoader(train_voc, shuffle=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device=torch.device('cpu')
dtype = torch.float32
num_class = 21  # 20 PASCAL VOC classes + background
Пример #16
0
# Notebook-cell artifact: bare expression displaying the training-set size.
len(train_ds)

"""We use about 20% of the original training set for our validation set, which uses just 1500 images."""

# Peek at one validation sample to confirm image shape and label.
img, label = val_ds[0]
print(img.shape, label)
len(val_ds)  # notebook-cell artifact: displays the validation-set size

"""We use SafeDataset to clean out any images that have broken links/won't work."""

# SafeDataset wrappers drop samples whose images fail to load.
training_set = nc.SafeDataset(train_ds)
validation_set = nc.SafeDataset(val_ds)

"""Likewise we use SafeDataLoader to load in our batch size for training the model (about half the dimensions of our images), and then we load in our validation set as well. We decided to choose a batch size about half the dimensions of our input images, while also doubling its size for the validation set."""

# Loaders: batch 64 for training (shuffled), 128 for validation.
train_dl = nc.SafeDataLoader(training_set, batch_size = 64, shuffle = True)
val_dl = nc.SafeDataLoader(validation_set, batch_size = 128, shuffle = False)

"""Now we create our CNN model called RussianArtClassifier. We create several layers in the model called features that we use by creating nn.Sequential objects. We first use a Conv2d method that creates a set of convolutional filters that use the first argument as the number of input channels, which for us is 3 since we are using color images. The second argument is the number of output channels, which for us is 32 channels, and then the kernel_size argument asks for how large our convolutional filter to be, which for us will be a 3x3 size, and then finally stride which controls how far the kernel moves on the input image (we also have padding but we kept it the same at 1 throughout the CNN). The output of a convolutional layer is given by (W−F+2P)/S+1, where W represents the weights (128x128x3), F is the kernel (3x3), P is pooling, and S is stride. We have 4 convolutional layers with ReLU activation function after each layer along with a batch normalization. The last layer is a max pooling operation, used to reduce the number of parameters to learn and computation needed to be performed, and we have it set up so that we down-sample our data by reducing the effective size of it by a factor of 2. All of the self.features of the convolutional layers are similar to one another, and we also include a drop-out layer to avoid over-fitting the model. We also have a fully connected layer called classifier which also uses ReLU activations. We have it all compiled in the forward function, which takes input argument x (which is the data being passed through the model), and we pass this data to all the convolutional layers and return the output as "out" and we also apply a view function after these layers are done that flattens out the data dimensions. Then the dropout is applied, followed by the fully connected layers, with the final output being returned from the function. 
It's also fed through a soft max function that converts our single vector of numbers into a vector of probabilities, 1-5 representing each type of art piece that can be chosen by the model. We also included a way to store the results for training loss, validation loss, and the total accuracies of the model, along with a function to get the accuracy by using the torch.max function which returns the index of the maximum value in a sensor, which we used from code we used in a previous homework called Classify Impressionists."""

class RussianArtClassifier(nn.Module):
    def __init__(self):
        super(RussianArtClassifier, self).__init__()
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size = 3, stride = 2, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),

            nn.MaxPool2d(kernel_size = 2, stride = 2))

        self.features2 = nn.Sequential( 
Пример #17
0
    import nonechucks as nc

    # Charades RGB clips; SafeDataset drops clips that fail to load.
    dataset = Charades(
        root='/vision/group/Charades_RGB/Charades_v1_rgb',
        split='train',
        labelpath='/vision/group/Charades/annotations/Charades_v1_train.csv',
        cachedir=
        '/vision2/u/rhsieh91/pytorch-i3d/charades_experiments/charades_cache',
        clip_size=16,
        is_val=False,
        transform=transform)
    dataset = nc.SafeDataset(dataset)

    # train_loader_1 = torch.utils.data.DataLoader(dataset,
    #                                            batch_size=8,
    #                                            shuffle=True,
    #                                            num_workers=0,
    #                                            pin_memory=True)
    train_loader_2 = nc.SafeDataLoader(dataset,
                                       batch_size=8,
                                       shuffle=True,
                                       num_workers=0,
                                       pin_memory=True)
    # pdb.set_trace()

    # BUG FIX: the loop previously iterated over `train_loader`, which is
    # not defined here (only `train_loader_2` is); use the loader built above.
    for i, a in enumerate(train_loader_2):
        print(a[0].shape)  # data
        print(a[1])  # action
        print(a[2])  # clip id
        if i == 10:  # sanity-check only the first few batches
            break
Пример #18
0
import random
from voc_seg import my_data,voc_colormap
import torchvision.models as model
import torchvision.transforms.functional as TF
from PIL import Image as image
import matplotlib.pyplot as plt
import nonechucks as nc
# from  matplotlib import pyplot as plt
# ImageNet normalization applied to the input images.
image_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)),
                                    ])
mask_transform=transforms.Compose([transforms.ToTensor()])# to_tensor will make it from nhwc to nchw
trainset=my_data((320,224),'data',transform=image_transform)
trainset=nc.SafeDataset(trainset)  # skip samples that raise during loading
testset=my_data((96,224),'data',transform=image_transform)
# loader=dataloader(trainset,batch_size=32,shuffle=True)
loader=nc.SafeDataLoader(trainset,batch_size=4)
# test_loader=dataloader(testset,batch_size=4)
vgg=model.vgg16(pretrained=True)  # pretrained VGG16 backbone
device=torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
device=torch.device('cpu')  # NOTE(review): overrides the CUDA device above -- confirm intentional
dtype = torch.float32
#%%
a,b=next(iter(loader))  # grab one (images, masks) batch for inspection
class_num=21  # 20 PASCAL VOC classes + background

#%%
def bilinear_kernel( in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else: