Example #1
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(conv2d_down_block, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
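
    # forward() is not part of this snippet; a minimal sketch consistent with
    # the conv1..convn attributes registered above (an assumption, not the
    # repository's verbatim method):
    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            x = getattr(self, 'conv%d' % i)(x)
        return x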
Example #2
    def __init__(self,
                 in_channels=3,
                 n_classes=1,
                 feature_scale=4,
                 is_deconv=True,
                 is_batchnorm=True):
        super(UNet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        # filters = [32, 64, 128, 256, 512]
        filters = [64, 128, 256, 512, 1024]
        # filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = conv2d_down_block(self.in_channels, filters[0],
                                       self.is_batchnorm)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = conv2d_down_block(filters[0], filters[1],
                                       self.is_batchnorm)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = conv2d_down_block(filters[1], filters[2],
                                       self.is_batchnorm)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = conv2d_down_block(filters[2], filters[3],
                                       self.is_batchnorm)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.center = conv2d_down_block(filters[3], filters[4],
                                        self.is_batchnorm)

        # upsampling
        self.up_concat4 = conv2d_up_block(filters[4], filters[3],
                                          self.is_deconv)
        self.up_concat3 = conv2d_up_block(filters[3], filters[2],
                                          self.is_deconv)
        self.up_concat2 = conv2d_up_block(filters[2], filters[1],
                                          self.is_deconv)
        self.up_concat1 = conv2d_up_block(filters[1], filters[0],
                                          self.is_deconv)
        self.outconv1 = nn.Conv2d(filters[0], n_classes, 3, padding=1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')
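
    # Sketch of a forward() consistent with the modules above; it assumes
    # conv2d_up_block takes (skip, lower-resolution feature) and returns the
    # fused map, as in Example #3:
    def forward(self, inputs):
        conv1 = self.conv1(inputs)
        conv2 = self.conv2(self.pool1(conv1))
        conv3 = self.conv3(self.pool2(conv2))
        conv4 = self.conv4(self.pool3(conv3))
        center = self.center(self.pool4(conv4))
        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)
        return self.outconv1(up1)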
Example #3
    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(conv2d_up_block, self).__init__()
        self.conv = conv2d_down_block(out_size * 2, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if isinstance(m, conv2d_down_block):
                continue  # already initialised inside conv2d_down_block
            init_weights(m, init_type='kaiming')
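
    # Sketch of the matching forward(): upsample the low-resolution feature,
    # concatenate it with the skip connection, then fuse. The argument order
    # is an assumption; note the out_size * 2 input to self.conv only lines up
    # exactly on the ConvTranspose2d path, since bilinear upsampling keeps
    # in_size channels.
    def forward(self, skip, low):
        up = self.up(low)
        return self.conv(torch.cat([skip, up], dim=1))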
Example #4
    def __init__(self, ndf=32):
        def conv_block(in_channels, out_channels):
            block = [
                nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(out_channels, out_channels, 3, 2, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, True),
            ]
            return block

        super(Discriminator,
              self).__init__(*conv_block(3, ndf), *conv_block(ndf, ndf * 2),
                             *conv_block(ndf * 2, ndf * 4),
                             *conv_block(ndf * 4, ndf * 8),
                             *conv_block(ndf * 8, ndf * 16),
                             nn.Conv2d(ndf * 16, 1024, kernel_size=1),
                             nn.LeakyReLU(0.2),
                             nn.Conv2d(1024, 1, kernel_size=1), nn.Sigmoid())

        models.init_weights(self, init_type='normal', init_gain=0.02)
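
# Usage sketch (shapes assumed): each conv_block halves the spatial size once,
# so five blocks downsample by 32 and a 3x256x256 image yields an 8x8 map of
# per-patch real/fake scores in [0, 1].
d = Discriminator(ndf=32)
scores = d(torch.randn(1, 3, 256, 256))
print(scores.shape)  # torch.Size([1, 1, 8, 8])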
Example #5
    def __init__(self, d_state, d_action, n_layers, n_units, activation):
        super().__init__()

        assert n_layers >= 1

        layers = [nn.Linear(d_state, n_units), get_activation(activation)]
        for _ in range(1, n_layers):
            layers += [nn.Linear(n_units, n_units), get_activation(activation)]
        layers += [nn.Linear(n_units, d_action)]

        for layer in layers:
            if isinstance(layer, nn.Linear):
                init_weights(layer)

        self.layers = nn.Sequential(*layers)
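
    # The matching forward() is not shown; presumably it simply applies the
    # sequential stack, mapping a d_state vector to a d_action vector:
    def forward(self, state):
        return self.layers(state)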
Example #6
    def __init__(self,
                 in_channels=3,
                 n_classes=1,
                 feature_scale=4,
                 is_deconv=True,
                 is_batchnorm=True,
                 is_ds=True):
        super(UNet_2Plus, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.is_ds = is_ds
        self.feature_scale = feature_scale

        # filters = [32, 64, 128, 256, 512]
        filters = [64, 128, 256, 512, 1024]
        # filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv00 = conv2d_down_block(self.in_channels, filters[0],
                                        self.is_batchnorm)
        self.maxpool0 = nn.MaxPool2d(kernel_size=2)
        self.conv10 = conv2d_down_block(filters[0], filters[1],
                                        self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv20 = conv2d_down_block(filters[1], filters[2],
                                        self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv30 = conv2d_down_block(filters[2], filters[3],
                                        self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv40 = conv2d_down_block(filters[3], filters[4],
                                        self.is_batchnorm)

        # upsampling
        self.up_concat01 = unetUp_origin(filters[1], filters[0],
                                         self.is_deconv)
        self.up_concat11 = unetUp_origin(filters[2], filters[1],
                                         self.is_deconv)
        self.up_concat21 = unetUp_origin(filters[3], filters[2],
                                         self.is_deconv)
        self.up_concat31 = unetUp_origin(filters[4], filters[3],
                                         self.is_deconv)

        self.up_concat02 = unetUp_origin(filters[1], filters[0],
                                         self.is_deconv, 3)
        self.up_concat12 = unetUp_origin(filters[2], filters[1],
                                         self.is_deconv, 3)
        self.up_concat22 = unetUp_origin(filters[3], filters[2],
                                         self.is_deconv, 3)

        self.up_concat03 = unetUp_origin(filters[1], filters[0],
                                         self.is_deconv, 4)
        self.up_concat13 = unetUp_origin(filters[2], filters[1],
                                         self.is_deconv, 4)

        self.up_concat04 = unetUp_origin(filters[1], filters[0],
                                         self.is_deconv, 5)

        # final conv (without any concat)
        self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
        self.final_4 = nn.Conv2d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')
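
    # Sketch of the nested-skip forward() these modules imply (assumed, not
    # verbatim): X_ij fuses the upsampled X_(i+1)(j-1) with all earlier
    # features X_i0..X_i(j-1) at the same resolution.
    def forward(self, inputs):
        X_00 = self.conv00(inputs)
        X_10 = self.conv10(self.maxpool0(X_00))
        X_20 = self.conv20(self.maxpool1(X_10))
        X_30 = self.conv30(self.maxpool2(X_20))
        X_40 = self.conv40(self.maxpool3(X_30))
        X_01 = self.up_concat01(X_10, X_00)
        X_11 = self.up_concat11(X_20, X_10)
        X_21 = self.up_concat21(X_30, X_20)
        X_31 = self.up_concat31(X_40, X_30)
        X_02 = self.up_concat02(X_11, X_00, X_01)
        X_12 = self.up_concat12(X_21, X_10, X_11)
        X_22 = self.up_concat22(X_31, X_20, X_21)
        X_03 = self.up_concat03(X_12, X_00, X_01, X_02)
        X_13 = self.up_concat13(X_22, X_10, X_11, X_12)
        X_04 = self.up_concat04(X_13, X_00, X_01, X_02, X_03)
        final_1 = self.final_1(X_01)
        final_2 = self.final_2(X_02)
        final_3 = self.final_3(X_03)
        final_4 = self.final_4(X_04)
        if self.is_ds:  # deep supervision: average the four side outputs
            return (final_1 + final_2 + final_3 + final_4) / 4
        return final_4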
Example #7
def train_one_folder(opt, folder):
    # Use specific GPU
    device = torch.device(opt.gpu_num)

    opt.folder = folder

    # Dataloaders
    train_dataset_file_path = os.path.join('../dataset', opt.source_domain,
                                           str(opt.folder), 'train.csv')
    train_loader = get_dataloader(train_dataset_file_path, 'train', opt)

    test_dataset_file_path = os.path.join('../dataset', opt.source_domain,
                                          str(opt.folder), 'test.csv')
    test_loader = get_dataloader(test_dataset_file_path, 'test', opt)

    # Model, optimizer and loss function
    emotion_recognizer = models.Model(opt)
    models.init_weights(emotion_recognizer)
    for param in emotion_recognizer.parameters():
        param.requires_grad = True
    emotion_recognizer.to(device)

    optimizer = torch.optim.Adam(emotion_recognizer.parameters(),
                                 lr=opt.learning_rate)
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                             patience=1)

    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.
    best_uar = 0.
    es = EarlyStopping(patience=opt.patience)

    # Train and validate
    for epoch in range(opt.epochs_num):
        if opt.verbose:
            print('epoch: {}/{}'.format(epoch + 1, opt.epochs_num))

        train_loss, train_acc = train(train_loader, emotion_recognizer,
                                      optimizer, criterion, device, opt)
        test_loss, test_acc, test_uar = test(test_loader, emotion_recognizer,
                                             criterion, device, opt)

        if opt.verbose:
            print('train_loss: {0:.5f}'.format(train_loss),
                  'train_acc: {0:.3f}'.format(train_acc),
                  'test_loss: {0:.5f}'.format(test_loss),
                  'test_acc: {0:.3f}'.format(test_acc),
                  'test_uar: {0:.3f}'.format(test_uar))

        lr_schedule.step(test_loss)

        os.makedirs(os.path.join(opt.logger_path, opt.source_domain),
                    exist_ok=True)

        model_file_name = os.path.join(opt.logger_path, opt.source_domain,
                                       'checkpoint.pth.tar')
        state = {
            'epoch': epoch + 1,
            'emotion_recognizer': emotion_recognizer.state_dict(),
            'opt': opt
        }
        torch.save(state, model_file_name)

        if test_acc > best_acc:
            model_file_name = os.path.join(opt.logger_path, opt.source_domain,
                                           'model.pth.tar')
            torch.save(state, model_file_name)

            best_acc = test_acc

        if test_uar > best_uar:
            best_uar = test_uar

        if es.step(test_loss):
            break

    return best_acc, best_uar
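
# EarlyStopping is not defined in this listing; a minimal sketch with the same
# interface (step() returns True once the monitored loss has failed to improve
# for `patience` consecutive epochs) might look like:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience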
Example #8
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = ChooseModel(network)(n_channels=3,
                               n_classes=conf['DATASET']['NUM_CLASSES'])
    assert net is not None, 'check your argument --network'

    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling\n'
        f'\t{"Using" if args.use_apex == "True" else "Not using"} Apex')
    init_weights(net, args.init_type)
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
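
# ChooseModel is not shown in this listing; given how it is called above, it
# is presumably a name-to-class dispatcher along these lines (keys and classes
# are illustrative assumptions):
def ChooseModel(name):
    networks = {
        'unet': UNet,
        'unet2plus': UNet_2Plus,
    }
    return networks.get(name.lower())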
Example #9
def main():
    train = True
    input_size = 768
    # input_size = 256  # set to none for default cropping
    print("Training with Places365 Dataset")
    max_its = 300000
    max_eps = 20000
    optimizer = 'adam'  # separate optimizers for discriminator and autoencoder
    lr = 0.0002
    batch_size = 1
    step_lr_gamma = 0.1
    step_lr_step = 200000
    discr_success_rate = 0.8
    win_rate = 0.8
    log_interval = int(max_its // 100)
    # log_interval = 100
    if log_interval < 10:
        print("\n WARNING: VERY SMALL LOG INTERVAL\n")

    lam = 0.001
    disc_wt = 1.
    trans_wt = 100.
    style_wt = 100.

    alpha = 0.05

    tblock_kernel = 10
    # Models
    encoder = models.Encoder()
    decoder = models.Decoder()
    tblock = models.TransformerBlock(kernel_size=tblock_kernel)
    discrim = models.Discriminator()

    # init weights
    models.init_weights(encoder)
    models.init_weights(decoder)
    models.init_weights(tblock)
    models.init_weights(discrim)

    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        tblock = tblock.cuda()
        discrim = discrim.cuda()

    if train:
        # load tmp weights
        if os.path.exists('tmp'):
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            encoder = torch.load("tmp/encoder.pt", map_location=device)
            decoder = torch.load("tmp/decoder.pt", map_location=device)
            tblock = torch.load("tmp/tblock.pt", map_location=device)
            discrim = torch.load("tmp/discriminator.pt", map_location=device)

        # Losses
        gen_loss = losses.SoftmaxLoss()
        disc_loss = losses.SoftmaxLoss()
        transf_loss = losses.TransformedLoss()
        style_aware_loss = losses.StyleAwareContentLoss()

        # # optimizer for encoder/decoder (and tblock? - think it has no parameters though)
        # params_to_update = []
        # for m in [encoder, decoder, tblock, discrim]:
        #     for param in m.parameters():
        #         param.requires_grad = True
        #         params_to_update.append(param)
        # # optimizer = torch.optim.Adam(params_to_update, lr=lr)

        data_dir = '../Datasets/WikiArt-Sorted/data/vincent-van-gogh_road-with-cypresses-1890'
        # data_dir = '../Datasets/WikiArt-Sorted/data/edvard-munch/'
        style_data = datasets.StyleDataset(data_dir)
        num_workers = 8
        # if mpii:
        #     dataloaders = {'train': DataLoader(datasets.MpiiDataset(train=True, input_size=input_size,
        #                                                             style_dataset=style_data, crop_size=crop_size),
        #                                        batch_size=batch_size, shuffle=True, num_workers=num_workers),
        #                    'test': DataLoader(datasets.MpiiDataset(train=False, style_dataset=style_data, input_size=input_size),
        #                                       batch_size=1, shuffle=False, num_workers=num_workers)}
        # else:
        dataloaders = {
            'train':
            DataLoader(datasets.PlacesDataset(train=True,
                                              input_size=input_size,
                                              style_dataset=style_data),
                       batch_size=batch_size,
                       shuffle=True,
                       num_workers=num_workers),
            'test':
            DataLoader(datasets.TestDataset(),
                       batch_size=1,
                       shuffle=False,
                       num_workers=num_workers)
        }

        # optimizer for encoder/decoder (and tblock? - think it has no parameters though)
        gen_params = []
        for m in [encoder, decoder]:
            for param in m.parameters():
                param.requires_grad = True
                gen_params.append(param)
        g_optimizer = torch.optim.Adam(gen_params, lr=lr)

        # optimizer for discriminator
        disc_params = []
        for param in discrim.parameters():
            param.requires_grad = True
            disc_params.append(param)
        d_optimizer = torch.optim.Adam(disc_params, lr=lr)

        scheduler_g = torch.optim.lr_scheduler.StepLR(g_optimizer,
                                                      step_lr_step,
                                                      gamma=step_lr_gamma,
                                                      last_epoch=-1)
        scheduler_d = torch.optim.lr_scheduler.StepLR(d_optimizer,
                                                      step_lr_step,
                                                      gamma=step_lr_gamma,
                                                      last_epoch=-1)

        its = 0
        print('Begin Training:')
        g_steps = 0
        d_steps = 0
        image_id = 0
        time_per_it = []
        if max_its is None:
            max_its = len(dataloaders['train'])

        # set models to train()
        encoder.train()
        decoder.train()
        tblock.train()
        discrim.train()

        d_loss = 0
        g_loss = 0
        gen_acc = 0
        d_acc = 0
        for epoch in range(max_eps):
            if its > max_its:
                break
            for images, style_images in dataloaders['train']:
                t0 = process_time()
                # utils.export_image(images[0, :, :, :], style_images[0, :, :, :], 'input_images.jpg')

                # zero gradients
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()

                if its > max_its:
                    break

                if torch.cuda.is_available():
                    images = images.cuda()
                    if style_images is not None:
                        style_images = style_images.cuda()

                # autoencoder
                emb = encoder(images)
                stylized_im = decoder(emb)

                # if training do losses etc
                stylized_emb = encoder(stylized_im)
                # add losses

                # tblock
                transformed_inputs, transformed_outputs = tblock(
                    images, stylized_im)
                # add loss

                # # # GENERATOR TRAIN # # # #
                g_optimizer.zero_grad()
                d_out_fake = discrim(
                    stylized_im
                )  # keep attached to generator because grads needed

                # accuracy given the fake output, generator images
                gen_acc = utils.accuracy(
                    d_out_fake,
                    target_label=1)  # accuracy given only the output image

                del g_loss
                g_loss = disc_wt * gen_loss(d_out_fake, target_label=1)
                g_loss += trans_wt * transf_loss(transformed_inputs,
                                                 transformed_outputs)
                g_loss += style_wt * style_aware_loss(emb, stylized_emb)
                g_loss.backward()
                g_optimizer.step()  # step the generator optimizer here, not the discriminator's
                discr_success_rate = discr_success_rate * (
                    1. - alpha) + alpha * (1. - gen_acc)
                g_steps += 1

                # # # DISCRIMINATOR TRAIN # # # #
                d_optimizer.zero_grad()
                # detach from generator, so not propagating unnecessary gradients
                d_out_fake = discrim(stylized_im.clone().detach())
                d_out_real_ph = discrim(images)
                d_out_real_style = discrim(style_images)

                # accuracy given all the images
                d_acc_real_ph = utils.accuracy(d_out_real_ph, target_label=0)
                d_acc_fake_style = utils.accuracy(d_out_fake, target_label=0)
                d_acc_real_style = utils.accuracy(d_out_real_style,
                                                  target_label=1)
                gen_acc = 1 - d_acc_fake_style
                d_acc = (d_acc_real_ph + d_acc_fake_style +
                         d_acc_real_style) / 3

                # Loss calculation
                d_loss = disc_loss(d_out_fake, target_label=0)
                d_loss += disc_loss(d_out_real_style, target_label=1)
                d_loss += disc_loss(d_out_real_ph, target_label=0)

                d_loss.backward()
                d_optimizer.step()
                discr_success_rate = discr_success_rate * (
                    1. - alpha) + alpha * d_acc
                d_steps += 1

                # print(g_loss.item(), g_steps, d_loss.item(), d_steps)
                t1 = process_time()
                time_per_it.append((t1 - t0) / 3600)
                if len(time_per_it) > 100:
                    time_per_it.pop(0)

                if not its % log_interval:
                    running_mean_it_time = sum(time_per_it) / len(time_per_it)
                    time_rem = (max_its - its + 1) * running_mean_it_time
                    print(
                        "{}/{} -- {} G Steps -- G Loss {:.2f} -- G Acc {:.2f} -"
                        "- {} D Steps -- D Loss {:.2f} -- D Acc {:.2f} -"
                        "- {:.2f} D Success -- {:.1f} Hours remaining...".format(
                            its, max_its, g_steps, g_loss, gen_acc, d_steps,
                            d_loss, d_acc, discr_success_rate, time_rem))

                    for idx in range(images.size(0)):
                        output_path = 'outputs/training/'
                        if not os.path.exists(output_path):
                            os.makedirs(output_path)

                        output_path += 'iteration_{:06d}_example_{}.jpg'.format(
                            its, idx)
                        utils.export_image([
                            images[idx, :, :, :], style_images[idx, :, :, :],
                            stylized_im[idx, :, :, :]
                        ], output_path)

                its += 1
                scheduler_g.step()
                scheduler_d.step()

                if not its % 10000:
                    if not os.path.exists('tmp'):
                        os.mkdir('tmp')
                    torch.save(encoder, "tmp/encoder.pt")
                    torch.save(decoder, "tmp/decoder.pt")
                    torch.save(tblock, "tmp/tblock.pt")
                    torch.save(discrim, "tmp/discriminator.pt")

        # save final weights
        torch.save(encoder, "encoder.pt")
        torch.save(decoder, "decoder.pt")
        torch.save(tblock, "tblock.pt")
        torch.save(discrim, "discriminator.pt")

        evaluate(encoder, decoder, dataloaders['test'])
    else:
        encoder = torch.load('encoder.pt', map_location='cpu')
        decoder = torch.load('decoder.pt', map_location='cpu')

        # encoder.load_state_dict(encoder_dict)
        # decoder.load_state_dict(decoder_dict)

        dataloader = DataLoader(datasets.TestDataset(),
                                batch_size=1,
                                shuffle=False,
                                num_workers=8)
        evaluate(encoder, decoder, dataloader)
        raise NotImplementedError('Standalone evaluation is not implemented')
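
# utils.accuracy is not defined in this listing; from its call sites, a
# plausible sketch is the fraction of sigmoid outputs falling on the requested
# side of 0.5 (an assumption about the helper, not its verbatim code):
def accuracy(predictions, target_label):
    predicted = (predictions > 0.5).float()
    return (predicted == target_label).float().mean().item()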
Example #10
    def __init__(self, ngf=64, n_blocks=16, use_weights=False):
        super(SRNTT, self).__init__()
        self.content_extractor = ContentExtractor(ngf, n_blocks)
        self.texture_transfer = TextureTransfer(ngf, n_blocks, use_weights)
        models.init_weights(self, init_type='normal', init_gain=0.02)
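
# models.init_weights appears throughout this listing; a sketch in the spirit
# of the widely used pix2pix/CycleGAN helper (assumed, not the verbatim
# original):
def init_weights(net, init_type='normal', init_gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, init_gain)
            nn.init.constant_(m.bias.data, 0.0)
    net.apply(init_func)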
Example #11
    def __init__(self,
                 in_channels=3,
                 n_classes=1,
                 feature_scale=4,
                 is_deconv=True,
                 is_batchnorm=True):
        super(UNet_3Plus_DeepSup_CGM, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]

        ## -------------Encoder--------------
        self.conv1 = conv2d_down_block(self.in_channels, filters[0],
                                       self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = conv2d_down_block(filters[0], filters[1],
                                       self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = conv2d_down_block(filters[1], filters[2],
                                       self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = conv2d_down_block(filters[2], filters[3],
                                       self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = conv2d_down_block(filters[3], filters[4],
                                       self.is_batchnorm)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks
        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0],
                                        self.CatChannels,
                                        3,
                                        padding=1)
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1],
                                        self.CatChannels,
                                        3,
                                        padding=1)
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2],
                                        self.CatChannels,
                                        3,
                                        padding=1)
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels,
                                  self.UpChannels,
                                  3,
                                  padding=1)  # 16
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu4d_1 = nn.ReLU(inplace=True)
        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0],
                                        self.CatChannels,
                                        3,
                                        padding=1)
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1],
                                        self.CatChannels,
                                        3,
                                        padding=1)
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels,
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels,
                                  self.UpChannels,
                                  3,
                                  padding=1)  # 16
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu3d_1 = nn.ReLU(inplace=True)
        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0],
                                        self.CatChannels,
                                        3,
                                        padding=1)
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels,
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels,
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(self.UpChannels,
                                  self.UpChannels,
                                  3,
                                  padding=1)  # 16
        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu2d_1 = nn.ReLU(inplace=True)
        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels,
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels,
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels,
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16,
                                      mode='bilinear')  # 14*14
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4],
                                         self.CatChannels,
                                         3,
                                         padding=1)
        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(self.UpChannels,
                                  self.UpChannels,
                                  3,
                                  padding=1)  # 16
        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
        self.relu1d_1 = nn.ReLU(inplace=True)

        # -------------Bilinear Upsampling--------------
        self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear')
        self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear')
        self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')

        # DeepSup
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
        self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
        self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
        self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
        self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1)

        self.cls = nn.Sequential(nn.Dropout(p=0.5),
                                 nn.Conv2d(filters[4], 2, 1),
                                 nn.AdaptiveMaxPool2d(1), nn.Sigmoid())

        # initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')
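
    # Sketch of how one decoder stage wires the modules above together; the
    # full forward() is outside this snippet, so stage 4d is shown as an
    # assumed illustration of the full-scale skip fusion (hypothetical helper
    # name):
    def _stage4d(self, h1, h2, h3, h4, hd5):
        h1_PT_hd4 = self.h1_PT_hd4_relu(
            self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(
            self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(
            self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(
            self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(
            self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
        # fuse the five CatChannels maps into one UpChannels decoder feature
        hd4 = torch.cat(
            (h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)
        return self.relu4d_1(self.bn4d_1(self.conv4d_1(hd4)))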
Example #12
def main():
    train = False
    resume = True
    input_size = 768
    artist = 'van-gogh'
    assert artist in artist_list
    # input_size = 256  # set to none for default cropping
    dual_optim = False
    print("Training with Places365 Dataset")
    max_its = 300000
    max_eps = 20000
    optimizer = 'adam'  # separate optimizers for discriminator and autoencoder
    lr = 0.0002
    batch_size = 1
    step_lr_gamma = 0.1
    step_lr_step = 200000
    discr_success_rate = 0.8
    win_rate = 0.8
    log_interval = int(max_its // 100)
    # log_interval = 100
    if log_interval < 10:
        print("\n WARNING: VERY SMALL LOG INTERVAL\n")

    lam = 0.001
    disc_wt = 1.
    trans_wt = 100.
    style_wt = 100.

    alpha = 0.05

    tblock_kernel = 10
    # Models
    encoder = models.Encoder()
    decoder = models.Decoder()
    tblock = models.TransformerBlock(kernel_size=tblock_kernel)
    discrim = models.Discriminator()

    # init weights
    models.init_weights(encoder)
    models.init_weights(decoder)
    models.init_weights(tblock)
    models.init_weights(discrim)

    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        tblock = tblock.cuda()
        discrim = discrim.cuda()

    artist_dir = glob.glob('../Datasets/WikiArt-Sorted/data/*')
    for item in artist_dir:
        if artist in os.path.basename(item):
            data_dir = item
            break
    print('Retrieving style examples from {} artwork from directory {}'.format(
        artist.upper(), data_dir))

    save_dir = 'outputs-{}'.format(artist)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    print('Saving weights and outputs to {}'.format(save_dir))

    if train:
        # load tmp weights
        if os.path.exists('tmp') and not resume:
            print('Loading from tmp...')
            assert os.path.exists("tmp/encoder.pt")
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            encoder = torch.load("tmp/encoder.pt", map_location=device)
            decoder = torch.load("tmp/decoder.pt", map_location=device)
            tblock = torch.load("tmp/tblock.pt", map_location=device)
            discrim = torch.load("tmp/discriminator.pt", map_location=device)

        # Losses
        gen_loss = losses.SoftmaxLoss()
        disc_loss = losses.SoftmaxLoss()
        transf_loss = losses.TransformedLoss()
        style_aware_loss = losses.StyleAwareContentLoss()

        if resume:
            print('Resuming training from {}...'.format(save_dir))
            assert os.path.exists(os.path.join(save_dir, 'encoder.pt'))
            lr *= step_lr_gamma
            max_its += 150000
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            encoder = torch.load(save_dir + "/encoder.pt", map_location=device)
            decoder = torch.load(save_dir + "/decoder.pt", map_location=device)
            tblock = torch.load(save_dir + "/tblock.pt", map_location=device)
            discrim = torch.load(save_dir + "/discriminator.pt",
                                 map_location=device)

        num_workers = 8
        dataloaders = {
            'train':
            DataLoader(datasets.PlacesDataset(train=True,
                                              input_size=input_size),
                       batch_size=batch_size,
                       shuffle=True,
                       num_workers=num_workers),
            'style':
            DataLoader(datasets.StyleDataset(data_dir=data_dir,
                                             input_size=input_size),
                       batch_size=batch_size,
                       shuffle=True,
                       num_workers=num_workers),
            'test':
            DataLoader(datasets.TestDataset(),
                       batch_size=1,
                       shuffle=False,
                       num_workers=num_workers)
        }

        # optimizer for encoder/decoder (and tblock? - think it has no parameters though)
        gen_params = []
        for m in [encoder, decoder]:
            for param in m.parameters():
                param.requires_grad = True
                gen_params.append(param)
        g_optimizer = torch.optim.Adam(gen_params, lr=lr)

        # optimizer for discriminator
        disc_params = []
        for param in discrim.parameters():
            param.requires_grad = True
            disc_params.append(param)
        d_optimizer = torch.optim.Adam(disc_params, lr=lr)

        scheduler_g = torch.optim.lr_scheduler.StepLR(g_optimizer,
                                                      step_lr_step,
                                                      gamma=step_lr_gamma,
                                                      last_epoch=-1)
        scheduler_d = torch.optim.lr_scheduler.StepLR(d_optimizer,
                                                      step_lr_step,
                                                      gamma=step_lr_gamma,
                                                      last_epoch=-1)

        its = 300000 if resume else 0
        print('Begin Training:')
        g_steps = 0
        d_steps = 0
        image_id = 0
        time_per_it = []
        if max_its is None:
            max_its = len(dataloaders['train'])

        # set models to train()
        encoder.train()
        decoder.train()
        tblock.train()
        discrim.train()

        d_loss = 0
        g_loss = 0
        gen_acc = 0
        d_acc = 0
        for epoch in range(max_eps):
            if its > max_its:
                break
            for images, style_images in zip(dataloaders['train'],
                                            cycle(dataloaders['style'])):
                t0 = process_time()
                # utils.export_image(images[0, :, :, :], style_images[0, :, :, :], 'input_images.jpg')

                # zero gradients
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()

                if its > max_its:
                    break

                if torch.cuda.is_available():
                    images = images.cuda()
                    if style_images is not None:
                        style_images = style_images.cuda()

                # autoencoder
                emb = encoder(images)
                stylized_im = decoder(emb)

                # if training do losses etc
                stylized_emb = encoder(stylized_im)

                if discr_success_rate < win_rate:
                    # discriminator train step
                    # discriminator
                    # detach from generator, so not propagating unnecessary gradients
                    d_out_fake = discrim(stylized_im.clone().detach())
                    d_out_real_ph = discrim(images)
                    d_out_real_style = discrim(style_images)

                    # accuracy given all the images
                    d_acc_real_ph = utils.accuracy(d_out_real_ph,
                                                   target_label=0)
                    d_acc_fake_style = utils.accuracy(d_out_fake,
                                                      target_label=0)
                    d_acc_real_style = utils.accuracy(d_out_real_style,
                                                      target_label=1)
                    gen_acc = 1 - d_acc_fake_style
                    d_acc = (d_acc_real_ph + d_acc_fake_style +
                             d_acc_real_style) / 3

                    # Loss calculation
                    d_loss = disc_loss(d_out_fake, target_label=0)
                    d_loss += disc_loss(d_out_real_style, target_label=1)
                    d_loss += disc_loss(d_out_real_ph, target_label=0)

                    # Step optimizer
                    d_loss.backward()
                    d_optimizer.step()
                    d_steps += 1

                    # Update success rate
                    discr_success_rate = discr_success_rate * (
                        1. - alpha) + alpha * d_acc
                else:
                    # generator train step
                    # Generator discrim losses

                    # discriminator
                    d_out_fake = discrim(
                        stylized_im
                    )  # keep attached to generator because grads needed

                    # accuracy given the fake output, generator images
                    gen_acc = utils.accuracy(
                        d_out_fake,
                        target_label=1)  # accuracy given only the output image

                    del g_loss
                    # tblock
                    transformed_inputs, transformed_outputs = tblock(
                        images, stylized_im)

                    g_loss = disc_wt * gen_loss(d_out_fake, target_label=1)
                    g_transf = trans_wt * transf_loss(transformed_inputs,
                                                      transformed_outputs)
                    g_style = style_wt * style_aware_loss(emb, stylized_emb)
                    # print(g_loss.item(), g_transf.item(), g_style.item())
                    g_loss += g_transf + g_style

                    # STEP OPTIMIZER
                    g_loss.backward()
                    g_optimizer.step()
                    g_steps += 1

                    # Update success rate
                    discr_success_rate = discr_success_rate * (
                        1. - alpha) + alpha * (1. - gen_acc)

                # report stuff
                t1 = process_time()
                time_per_it.append((t1 - t0) / 3600)
                if len(time_per_it) > 100:
                    time_per_it.pop(0)

                if not its % log_interval:
                    running_mean_it_time = sum(time_per_it) / len(time_per_it)
                    time_rem = (max_its - its + 1) * running_mean_it_time
                    print(
                        "{}/{} -- {} G Steps -- G Loss {:.2f} -- G Acc {:.2f} -"
                        "- {} D Steps -- D Loss {:.2f} -- D Acc {:.2f} -"
                        "- {:.2f} D Success -- {:.1f} Hours remaining...".format(
                            its, max_its, g_steps, g_loss, gen_acc, d_steps,
                            d_loss, d_acc, discr_success_rate, time_rem))

                    for idx in range(images.size(0)):
                        output_path = os.path.join(save_dir,
                                                   'training_visualise')
                        if not os.path.exists(output_path):
                            os.makedirs(output_path)

                        output_path = os.path.join(
                            output_path,
                            'iteration_{:06d}_example_{}.jpg'.format(its, idx))
                        utils.export_image([
                            images[idx, :, :, :], style_images[idx, :, :, :],
                            stylized_im[idx, :, :, :]
                        ], output_path)

                its += 1
                scheduler_g.step()
                scheduler_d.step()

                if not its % 10000:
                    if not os.path.exists(save_dir):
                        os.mkdir(save_dir)
                    torch.save(encoder, save_dir + "/encoder.pt")
                    torch.save(decoder, save_dir + "/decoder.pt")
                    torch.save(tblock, save_dir + "/tblock.pt")
                    torch.save(discrim, save_dir + "/discriminator.pt")

        # save final weights
        torch.save(encoder, os.path.join(save_dir, "encoder.pt"))
        torch.save(decoder, os.path.join(save_dir, "decoder.pt"))
        torch.save(tblock, os.path.join(save_dir, "tblock.pt"))
        torch.save(discrim, os.path.join(save_dir, "discriminator.pt"))

        evaluate(encoder, decoder, dataloaders['test'], save_dir=save_dir)
    else:
        print('Loading Models {} and {}'.format(
            os.path.join(save_dir, "encoder.pt"),
            os.path.join(save_dir, "decoder.pt")))
        encoder = torch.load(os.path.join(save_dir, "encoder.pt"),
                             map_location='cpu')
        decoder = torch.load(os.path.join(save_dir, "decoder.pt"),
                             map_location='cpu')

        # encoder.load_state_dict(encoder_dict)
        # decoder.load_state_dict(decoder_dict)
        if torch.cuda.is_available():
            encoder = encoder.cuda()
            decoder = decoder.cuda()

        dataloader = DataLoader(datasets.TestDataset(input_size=input_size),
                                batch_size=1,
                                shuffle=False,
                                num_workers=8)
        evaluate(encoder, decoder, dataloader, save_dir=save_dir)
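
# The zip(dataloaders['train'], cycle(dataloaders['style'])) pairing above
# draws one style batch per content batch, restarting the (smaller) style
# loader as needed; cycle comes from itertools. Note that cycle caches items
# from its first pass, so the style batches repeat in the same order. A
# self-contained illustration:
from itertools import cycle

for content, style in zip(range(5), cycle(['a', 'b'])):
    print(content, style)  # 0 a, 1 b, 2 a, 3 b, 4 a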