def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
    super(conv2d_down_block, self).__init__()
    self.n = n
    self.ks = ks
    self.stride = stride
    self.padding = padding
    s = stride
    p = padding

    if is_batchnorm:
        for i in range(1, n + 1):
            conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                 nn.BatchNorm2d(out_size),
                                 nn.ReLU(inplace=True))
            setattr(self, 'conv%d' % i, conv)
            in_size = out_size
    else:
        for i in range(1, n + 1):
            conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                 nn.ReLU(inplace=True))
            setattr(self, 'conv%d' % i, conv)
            in_size = out_size

    # initialise the blocks
    for m in self.children():
        init_weights(m, init_type='kaiming')
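# The forward pass is not part of this snippet. Below is a minimal sketch of the
# method this block is typically paired with, matching how conv1..conv<n> are
# registered above (assumes the same torch.nn imports as the constructor).
def forward(self, inputs):
    x = inputs
    for i in range(1, self.n + 1):
        conv = getattr(self, 'conv%d' % i)
        x = conv(x)
    return x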
def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):
    super(UNet, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    # filters = [32, 64, 128, 256, 512]
    filters = [64, 128, 256, 512, 1024]
    # filters = [int(x / self.feature_scale) for x in filters]

    # downsampling
    self.conv1 = conv2d_down_block(self.in_channels, filters[0], self.is_batchnorm)
    self.pool1 = nn.MaxPool2d(kernel_size=2)
    self.conv2 = conv2d_down_block(filters[0], filters[1], self.is_batchnorm)
    self.pool2 = nn.MaxPool2d(kernel_size=2)
    self.conv3 = conv2d_down_block(filters[1], filters[2], self.is_batchnorm)
    self.pool3 = nn.MaxPool2d(kernel_size=2)
    self.conv4 = conv2d_down_block(filters[2], filters[3], self.is_batchnorm)
    self.pool4 = nn.MaxPool2d(kernel_size=2)
    self.center = conv2d_down_block(filters[3], filters[4], self.is_batchnorm)

    # upsampling
    self.up_concat4 = conv2d_up_block(filters[4], filters[3], self.is_deconv)
    self.up_concat3 = conv2d_up_block(filters[3], filters[2], self.is_deconv)
    self.up_concat2 = conv2d_up_block(filters[2], filters[1], self.is_deconv)
    self.up_concat1 = conv2d_up_block(filters[1], filters[0], self.is_deconv)

    self.outconv1 = nn.Conv2d(filters[0], 1, 3, padding=1)

    # initialise weights
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            init_weights(m, init_type='kaiming')
        elif isinstance(m, nn.BatchNorm2d):
            init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv, n_concat=2):
    super(conv2d_up_block, self).__init__()
    self.conv = conv2d_down_block(out_size * 2, out_size, False)
    if is_deconv:
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
    else:
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    # initialise the blocks
    for m in self.children():
        if m.__class__.__name__.find('conv2d_down_block') != -1:
            continue
        init_weights(m, init_type='kaiming')
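# The forward pass is not part of this snippet either. A sketch consistent with the
# layers above: upsample the deeper feature map and concatenate it with the skip
# connection(s) before the conv block. With is_deconv=True the transposed conv maps
# in_size -> out_size channels, so one out_size-channel skip yields the out_size * 2
# input channels that self.conv expects (torch assumed in scope).
def forward(self, high_feature, *skips):
    out = self.up(high_feature)
    for skip in skips:
        out = torch.cat([out, skip], dim=1)
    return self.conv(out)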
def __init__(self, ndf=32):
    def conv_block(in_channels, out_channels):
        block = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, 3, 2, 1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        ]
        return block

    super(Discriminator, self).__init__(
        *conv_block(3, ndf),
        *conv_block(ndf, ndf * 2),
        *conv_block(ndf * 2, ndf * 4),
        *conv_block(ndf * 4, ndf * 8),
        *conv_block(ndf * 8, ndf * 16),
        nn.Conv2d(ndf * 16, 1024, kernel_size=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(1024, 1, kernel_size=1),
        nn.Sigmoid())

    models.init_weights(self, init_type='normal', init_gain=0.02)
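# Quick shape check (sketch; assumes torch and the models.init_weights helper used
# above are importable). Discriminator appears to subclass nn.Sequential, since its
# super().__init__ receives the layers directly, so calling it just runs the stack:
# five stride-2 blocks downsample by 32x, and the two 1x1 convs plus Sigmoid give a
# 1-channel map of real/fake probabilities.
# d = Discriminator(ndf=32)
# out = d(torch.randn(1, 3, 256, 256))
# print(out.shape)  # torch.Size([1, 1, 8, 8])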
def __init__(self, d_state, d_action, n_layers, n_units, activation):
    super().__init__()

    assert n_layers >= 1

    layers = [nn.Linear(d_state, n_units), get_activation(activation)]
    for _ in range(1, n_layers):
        layers += [nn.Linear(n_units, n_units), get_activation(activation)]
    layers += [nn.Linear(n_units, d_action)]

    for layer in layers:
        if isinstance(layer, nn.Linear):
            init_weights(layer)

    self.layers = nn.Sequential(*layers)
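# Minimal usage sketch. The enclosing class name is not shown in this snippet, so
# `StateActionMLP` below is a hypothetical stand-in; it also assumes forward() simply
# applies self.layers, and that torch plus the get_activation/init_weights helpers
# referenced above are in scope.
# net = StateActionMLP(d_state=17, d_action=6, n_layers=2, n_units=256, activation='relu')
# actions = net(torch.randn(32, 17))  # -> tensor of shape (32, 6)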
def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True, is_ds=True):
    super(UNet_2Plus, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.is_ds = is_ds
    self.feature_scale = feature_scale

    # filters = [32, 64, 128, 256, 512]
    filters = [64, 128, 256, 512, 1024]
    # filters = [int(x / self.feature_scale) for x in filters]

    # downsampling
    self.conv00 = conv2d_down_block(self.in_channels, filters[0], self.is_batchnorm)
    self.maxpool0 = nn.MaxPool2d(kernel_size=2)
    self.conv10 = conv2d_down_block(filters[0], filters[1], self.is_batchnorm)
    self.maxpool1 = nn.MaxPool2d(kernel_size=2)
    self.conv20 = conv2d_down_block(filters[1], filters[2], self.is_batchnorm)
    self.maxpool2 = nn.MaxPool2d(kernel_size=2)
    self.conv30 = conv2d_down_block(filters[2], filters[3], self.is_batchnorm)
    self.maxpool3 = nn.MaxPool2d(kernel_size=2)
    self.conv40 = conv2d_down_block(filters[3], filters[4], self.is_batchnorm)

    # upsampling
    self.up_concat01 = unetUp_origin(filters[1], filters[0], self.is_deconv)
    self.up_concat11 = unetUp_origin(filters[2], filters[1], self.is_deconv)
    self.up_concat21 = unetUp_origin(filters[3], filters[2], self.is_deconv)
    self.up_concat31 = unetUp_origin(filters[4], filters[3], self.is_deconv)

    self.up_concat02 = unetUp_origin(filters[1], filters[0], self.is_deconv, 3)
    self.up_concat12 = unetUp_origin(filters[2], filters[1], self.is_deconv, 3)
    self.up_concat22 = unetUp_origin(filters[3], filters[2], self.is_deconv, 3)

    self.up_concat03 = unetUp_origin(filters[1], filters[0], self.is_deconv, 4)
    self.up_concat13 = unetUp_origin(filters[2], filters[1], self.is_deconv, 4)

    self.up_concat04 = unetUp_origin(filters[1], filters[0], self.is_deconv, 5)

    # final conv (without any concat)
    self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
    self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
    self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
    self.final_4 = nn.Conv2d(filters[0], n_classes, 1)

    # initialise weights
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            init_weights(m, init_type='kaiming')
        elif isinstance(m, nn.BatchNorm2d):
            init_weights(m, init_type='kaiming')
def train_one_folder(opt, folder):
    # Use specific GPU
    device = torch.device(opt.gpu_num)

    opt.folder = folder

    # Dataloaders
    train_dataset_file_path = os.path.join('../dataset', opt.source_domain, str(opt.folder), 'train.csv')
    train_loader = get_dataloader(train_dataset_file_path, 'train', opt)
    test_dataset_file_path = os.path.join('../dataset', opt.source_domain, str(opt.folder), 'test.csv')
    test_loader = get_dataloader(test_dataset_file_path, 'test', opt)

    # Model, optimizer and loss function
    emotion_recognizer = models.Model(opt)
    models.init_weights(emotion_recognizer)
    for param in emotion_recognizer.parameters():
        param.requires_grad = True
    emotion_recognizer.to(device)

    optimizer = torch.optim.Adam(emotion_recognizer.parameters(), lr=opt.learning_rate)
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.
    best_uar = 0.
    es = EarlyStopping(patience=opt.patience)

    # Train and validate
    for epoch in range(opt.epochs_num):
        if opt.verbose:
            print('epoch: {}/{}'.format(epoch + 1, opt.epochs_num))

        train_loss, train_acc = train(train_loader, emotion_recognizer, optimizer, criterion, device, opt)
        test_loss, test_acc, test_uar = test(test_loader, emotion_recognizer, criterion, device, opt)

        if opt.verbose:
            print('train_loss: {0:.5f}'.format(train_loss),
                  'train_acc: {0:.3f}'.format(train_acc),
                  'test_loss: {0:.5f}'.format(test_loss),
                  'test_acc: {0:.3f}'.format(test_acc),
                  'test_uar: {0:.3f}'.format(test_uar))

        lr_schedule.step(test_loss)

        os.makedirs(os.path.join(opt.logger_path, opt.source_domain), exist_ok=True)
        model_file_name = os.path.join(opt.logger_path, opt.source_domain, 'checkpoint.pth.tar')
        state = {
            'epoch': epoch + 1,
            'emotion_recognizer': emotion_recognizer.state_dict(),
            'opt': opt
        }
        torch.save(state, model_file_name)

        if test_acc > best_acc:
            model_file_name = os.path.join(opt.logger_path, opt.source_domain, 'model.pth.tar')
            torch.save(state, model_file_name)
            best_acc = test_acc
        if test_uar > best_uar:
            best_uar = test_uar

        if es.step(test_loss):
            break

    return best_acc, best_uar
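# Sketch: restoring the best checkpoint written above. The dict keys mirror the
# `state` saved in train_one_folder; opt and device are assumed to be the same
# objects used during training.
# checkpoint = torch.load(os.path.join(opt.logger_path, opt.source_domain, 'model.pth.tar'),
#                         map_location=device)
# emotion_recognizer = models.Model(checkpoint['opt'])
# emotion_recognizer.load_state_dict(checkpoint['emotion_recognizer'])
# emotion_recognizer.to(device).eval()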
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
#   - For 1 class and background, use n_classes=1
#   - For 2 classes, use n_classes=1
#   - For N > 2 classes, use n_classes=N
net = ChooseModel(network)(n_channels=3, n_classes=conf['DATASET']['NUM_CLASSES'])
assert net is not None, f'check your argument --network'

logging.info(f'Network:\n'
             f'\t{net.n_channels} input channels\n'
             f'\t{net.n_classes} output channels (classes)\n'
             f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling\n'
             f'\t{"Using" if args.use_apex == "True" else "Not using"} Apex')

init_weights(net, args.init_type)
if args.load:
    net.load_state_dict(torch.load(args.load, map_location=device))
    logging.info(f'Model loaded from {args.load}')

net.to(device=device)
# faster convolutions, but more memory
# cudnn.benchmark = True

try:
    train_net(net=net,
              epochs=args.epochs,
              batch_size=args.batchsize,
              lr=args.lr,
              device=device,
              img_scale=args.scale,
def main():
    train = True
    input_size = 768
    # input_size = 256  # set to none for default cropping
    print("Training with Places365 Dataset")
    max_its = 300000
    max_eps = 20000
    optimizer = 'adam'  # separate optimizers for discriminator and autoencoder
    lr = 0.0002
    batch_size = 1
    step_lr_gamma = 0.1
    step_lr_step = 200000

    discr_success_rate = 0.8
    win_rate = 0.8

    log_interval = int(max_its // 100)
    # log_interval = 100
    if log_interval < 10:
        print("\n WARNING: VERY SMALL LOG INTERVAL\n")

    lam = 0.001
    disc_wt = 1.
    trans_wt = 100.
    style_wt = 100.
    alpha = 0.05
    tblock_kernel = 10

    # Models
    encoder = models.Encoder()
    decoder = models.Decoder()
    tblock = models.TransformerBlock(kernel_size=tblock_kernel)
    discrim = models.Discriminator()

    # init weights
    models.init_weights(encoder)
    models.init_weights(decoder)
    models.init_weights(tblock)
    models.init_weights(discrim)

    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        tblock = tblock.cuda()
        discrim = discrim.cuda()

    if train:
        # load tmp weights
        if os.path.exists('tmp'):
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            encoder = torch.load("tmp/encoder.pt", map_location=device)
            decoder = torch.load("tmp/decoder.pt", map_location=device)
            tblock = torch.load("tmp/tblock.pt", map_location=device)
            discrim = torch.load("tmp/discriminator.pt", map_location=device)

        # Losses
        gen_loss = losses.SoftmaxLoss()
        disc_loss = losses.SoftmaxLoss()
        transf_loss = losses.TransformedLoss()
        style_aware_loss = losses.StyleAwareContentLoss()

        # # optimizer for encoder/decoder (and tblock? - think it has no parameters though)
        # params_to_update = []
        # for m in [encoder, decoder, tblock, discrim]:
        #     for param in m.parameters():
        #         param.requires_grad = True
        #         params_to_update.append(param)
        #
        # optimizer = torch.optim.Adam(params_to_update, lr=lr)

        data_dir = '../Datasets/WikiArt-Sorted/data/vincent-van-gogh_road-with-cypresses-1890'
        # data_dir = '../Datasets/WikiArt-Sorted/data/edvard-munch/'
        style_data = datasets.StyleDataset(data_dir)
        num_workers = 8

        # if mpii:
        #     dataloaders = {'train': DataLoader(datasets.MpiiDataset(train=True, input_size=input_size,
        #                                                             style_dataset=style_data, crop_size=crop_size),
        #                                        batch_size=batch_size, shuffle=True, num_workers=num_workers),
        #                    'test': DataLoader(datasets.MpiiDataset(train=False, style_dataset=style_data,
        #                                                            input_size=input_size),
        #                                       batch_size=1, shuffle=False, num_workers=num_workers)}
        # else:
        dataloaders = {
            'train': DataLoader(datasets.PlacesDataset(train=True, input_size=input_size, style_dataset=style_data),
                                batch_size=batch_size, shuffle=True, num_workers=num_workers),
            'test': DataLoader(datasets.TestDataset(), batch_size=1, shuffle=False, num_workers=num_workers)
        }

        # optimizer for encoder/decoder (and tblock? - think it has no parameters though)
        gen_params = []
        for m in [encoder, decoder]:
            for param in m.parameters():
                param.requires_grad = True
                gen_params.append(param)
        g_optimizer = torch.optim.Adam(gen_params, lr=lr)

        # optimizer for discriminator
        disc_params = []
        for param in discrim.parameters():
            param.requires_grad = True
            disc_params.append(param)
        d_optimizer = torch.optim.Adam(disc_params, lr=lr)

        scheduler_g = torch.optim.lr_scheduler.StepLR(g_optimizer, step_lr_step, gamma=step_lr_gamma, last_epoch=-1)
        scheduler_d = torch.optim.lr_scheduler.StepLR(d_optimizer, step_lr_step, gamma=step_lr_gamma, last_epoch=-1)

        its = 0
        print('Begin Training:')
        g_steps = 0
        d_steps = 0
        image_id = 0
        time_per_it = []

        if max_its is None:
            max_its = len(dataloaders['train'])

        # set models to train()
        encoder.train()
        decoder.train()
        tblock.train()
        discrim.train()

        d_loss = 0
        g_loss = 0
        gen_acc = 0
        d_acc = 0

        for epoch in range(max_eps):
            if its > max_its:
                break
            for images, style_images in dataloaders['train']:
                t0 = process_time()
                # utils.export_image(images[0, :, :, :], style_images[0, :, :, :], 'input_images.jpg')

                # zero gradients
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()

                if its > max_its:
                    break

                if torch.cuda.is_available():
                    images = images.cuda()
                    if style_images is not None:
                        style_images = style_images.cuda()

                # autoencoder
                emb = encoder(images)
                stylized_im = decoder(emb)

                # if training do losses etc
                stylized_emb = encoder(stylized_im)  # add losses

                # tblock
                transformed_inputs, transformed_outputs = tblock(images, stylized_im)  # add loss

                # # # GENERATOR TRAIN # # #
                # g_optimizer.zero_grad()
                d_out_fake = discrim(stylized_im)  # keep attached to generator because grads needed

                # accuracy given the fake output, generator images
                gen_acc = utils.accuracy(d_out_fake, target_label=1)  # accuracy given only the output image

                del g_loss
                g_loss = disc_wt * gen_loss(d_out_fake, target_label=1)
                g_loss += trans_wt * transf_loss(transformed_inputs, transformed_outputs)
                g_loss += style_wt * style_aware_loss(emb, stylized_emb)
                g_loss.backward()
                g_optimizer.step()
                discr_success_rate = discr_success_rate * (1. - alpha) + alpha * (1. - gen_acc)
                g_steps += 1

                # # # DISCRIMINATOR TRAIN # # #
                # d_optimizer.zero_grad()
                # detach from generator, so not propagating unnecessary gradients
                d_out_fake = discrim(stylized_im.clone().detach())
                d_out_real_ph = discrim(images)
                d_out_real_style = discrim(style_images)

                # accuracy given all the images
                d_acc_real_ph = utils.accuracy(d_out_real_ph, target_label=0)
                d_acc_fake_style = utils.accuracy(d_out_fake, target_label=0)
                d_acc_real_style = utils.accuracy(d_out_real_style, target_label=1)
                gen_acc = 1 - d_acc_fake_style
                d_acc = (d_acc_real_ph + d_acc_fake_style + d_acc_real_style) / 3

                # Loss calculation
                d_loss = disc_loss(d_out_fake, target_label=0)
                d_loss += disc_loss(d_out_real_style, target_label=1)
                d_loss += disc_loss(d_out_real_ph, target_label=0)
                d_loss.backward()
                d_optimizer.step()
                discr_success_rate = discr_success_rate * (1. - alpha) + alpha * d_acc
                d_steps += 1
                # print(g_loss.item(), g_steps, d_loss.item(), d_steps)

                t1 = process_time()
                time_per_it.append((t1 - t0) / 3600)
                if len(time_per_it) > 100:
                    time_per_it.pop(0)

                if not its % log_interval:
                    running_mean_it_time = sum(time_per_it) / len(time_per_it)
                    time_rem = (max_its - its + 1) * running_mean_it_time
                    print("{}/{} -- {} G Steps -- G Loss {:.2f} -- G Acc {:.2f} -"
                          "- {} D Steps -- D Loss {:.2f} -- D Acc {:.2f} -"
                          "- {:.2f} D Success -- {:.1f} Hours remaining...".format(
                              its, max_its, g_steps, g_loss, gen_acc,
                              d_steps, d_loss, d_acc, discr_success_rate, time_rem))

                    for idx in range(images.size(0)):
                        output_path = 'outputs/training/'
                        if not os.path.exists(output_path):
                            os.makedirs(output_path)
                        output_path += 'iteration_{:06d}_example_{}.jpg'.format(its, idx)
                        utils.export_image([images[idx, :, :, :],
                                            style_images[idx, :, :, :],
                                            stylized_im[idx, :, :, :]], output_path)

                its += 1
                scheduler_g.step()
                scheduler_d.step()

                if not its % 10000:
                    if not os.path.exists('tmp'):
                        os.mkdir('tmp')
                    torch.save(encoder, "tmp/encoder.pt")
                    torch.save(decoder, "tmp/decoder.pt")
                    torch.save(tblock, "tmp/tblock.pt")
                    torch.save(discrim, "tmp/discriminator.pt")

        # only save if running on gpu (otherwise I'm just fixing bugs)
        torch.save(encoder, "encoder.pt")
        torch.save(decoder, "decoder.pt")
        torch.save(tblock, "tblock.pt")
        torch.save(discrim, "discriminator.pt")

        evaluate(encoder, decoder, dataloaders['test'])
    else:
        encoder = torch.load('encoder.pt', map_location='cpu')
        decoder = torch.load('decoder.pt', map_location='cpu')
        # encoder.load_state_dict(encoder_dict)
        # decoder.load_state_dict(decoder_dict)

        dataloader = DataLoader(datasets.TestDataset(), batch_size=1, shuffle=False, num_workers=8)
        evaluate(encoder, decoder, dataloader)
        raise NotImplementedError('Not implemented standalone ')
def __init__(self, ngf=64, n_blocks=16, use_weights=False):
    super(SRNTT, self).__init__()
    self.content_extractor = ContentExtractor(ngf, n_blocks)
    self.texture_transfer = TextureTransfer(ngf, n_blocks, use_weights)
    models.init_weights(self, init_type='normal', init_gain=0.02)
def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True):
    super(UNet_3Plus_DeepSup_CGM, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [64, 128, 256, 512, 1024]

    ## -------------Encoder--------------
    self.conv1 = conv2d_down_block(self.in_channels, filters[0], self.is_batchnorm)
    self.maxpool1 = nn.MaxPool2d(kernel_size=2)

    self.conv2 = conv2d_down_block(filters[0], filters[1], self.is_batchnorm)
    self.maxpool2 = nn.MaxPool2d(kernel_size=2)

    self.conv3 = conv2d_down_block(filters[1], filters[2], self.is_batchnorm)
    self.maxpool3 = nn.MaxPool2d(kernel_size=2)

    self.conv4 = conv2d_down_block(filters[2], filters[3], self.is_batchnorm)
    self.maxpool4 = nn.MaxPool2d(kernel_size=2)

    self.conv5 = conv2d_down_block(filters[3], filters[4], self.is_batchnorm)

    ## -------------Decoder--------------
    self.CatChannels = filters[0]
    self.CatBlocks = 5
    self.UpChannels = self.CatChannels * self.CatBlocks

    '''stage 4d'''
    # h1->320*320, hd4->40*40, Pooling 8 times
    self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
    self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    self.h1_PT_hd4_relu = nn.ReLU(inplace=True)

    # h2->160*160, hd4->40*40, Pooling 4 times
    self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
    self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
    self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    self.h2_PT_hd4_relu = nn.ReLU(inplace=True)

    # h3->80*80, hd4->40*40, Pooling 2 times
    self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
    self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
    self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    self.h3_PT_hd4_relu = nn.ReLU(inplace=True)

    # h4->40*40, hd4->40*40, Concatenation
    self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
    self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)

    # hd5->20*20, hd4->40*40, Upsample 2 times
    self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)

    # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
    self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
    self.relu4d_1 = nn.ReLU(inplace=True)

    '''stage 3d'''
    # h1->320*320, hd3->80*80, Pooling 4 times
    self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
    self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    self.h1_PT_hd3_relu = nn.ReLU(inplace=True)

    # h2->160*160, hd3->80*80, Pooling 2 times
    self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
    self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
    self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    self.h2_PT_hd3_relu = nn.ReLU(inplace=True)

    # h3->80*80, hd3->80*80, Concatenation
    self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
    self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)

    # hd4->40*40, hd3->80*80, Upsample 2 times
    self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)

    # hd5->20*20, hd3->80*80, Upsample 4 times
    self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
    self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)

    # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
    self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
    self.relu3d_1 = nn.ReLU(inplace=True)

    '''stage 2d'''
    # h1->320*320, hd2->160*160, Pooling 2 times
    self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
    self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    self.h1_PT_hd2_relu = nn.ReLU(inplace=True)

    # h2->160*160, hd2->160*160, Concatenation
    self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
    self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)

    # hd3->80*80, hd2->160*160, Upsample 2 times
    self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)

    # hd4->40*40, hd2->160*160, Upsample 4 times
    self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
    self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)

    # hd5->20*20, hd2->160*160, Upsample 8 times
    self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
    self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)

    # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
    self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
    self.relu2d_1 = nn.ReLU(inplace=True)

    '''stage 1d'''
    # h1->320*320, hd1->320*320, Concatenation
    self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
    self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)

    # hd2->160*160, hd1->320*320, Upsample 2 times
    self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')  # 14*14
    self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)

    # hd3->80*80, hd1->320*320, Upsample 4 times
    self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')  # 14*14
    self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)

    # hd4->40*40, hd1->320*320, Upsample 8 times
    self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')  # 14*14
    self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
    self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)

    # hd5->20*20, hd1->320*320, Upsample 16 times
    self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')  # 14*14
    self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
    self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
    self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)

    # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
    self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1)  # 16
    self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
    self.relu1d_1 = nn.ReLU(inplace=True)

    # -------------Bilinear Upsampling--------------
    self.upscore6 = nn.Upsample(scale_factor=32, mode='bilinear')
    self.upscore5 = nn.Upsample(scale_factor=16, mode='bilinear')
    self.upscore4 = nn.Upsample(scale_factor=8, mode='bilinear')
    self.upscore3 = nn.Upsample(scale_factor=4, mode='bilinear')
    self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')

    # DeepSup
    self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
    self.outconv2 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
    self.outconv3 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
    self.outconv4 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
    self.outconv5 = nn.Conv2d(filters[4], n_classes, 3, padding=1)

    self.cls = nn.Sequential(nn.Dropout(p=0.5),
                             nn.Conv2d(filters[4], 2, 1),
                             nn.AdaptiveMaxPool2d(1),
                             nn.Sigmoid())

    # initialise weights
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            init_weights(m, init_type='kaiming')
        elif isinstance(m, nn.BatchNorm2d):
            init_weights(m, init_type='kaiming')
def main():
    train = False
    resume = True
    input_size = 768
    artist = 'van-gogh'
    assert artist in artist_list
    # input_size = 256  # set to none for default cropping
    dual_optim = False
    print("Training with Places365 Dataset")
    max_its = 300000
    max_eps = 20000
    optimizer = 'adam'  # separate optimizers for discriminator and autoencoder
    lr = 0.0002
    batch_size = 1
    step_lr_gamma = 0.1
    step_lr_step = 200000

    discr_success_rate = 0.8
    win_rate = 0.8

    log_interval = int(max_its // 100)
    # log_interval = 100
    if log_interval < 10:
        print("\n WARNING: VERY SMALL LOG INTERVAL\n")

    lam = 0.001
    disc_wt = 1.
    trans_wt = 100.
    style_wt = 100.
    alpha = 0.05
    tblock_kernel = 10

    # Models
    encoder = models.Encoder()
    decoder = models.Decoder()
    tblock = models.TransformerBlock(kernel_size=tblock_kernel)
    discrim = models.Discriminator()

    # init weights
    models.init_weights(encoder)
    models.init_weights(decoder)
    models.init_weights(tblock)
    models.init_weights(discrim)

    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        tblock = tblock.cuda()
        discrim = discrim.cuda()

    artist_dir = glob.glob('../Datasets/WikiArt-Sorted/data/*')
    for item in artist_dir:
        if artist in os.path.basename(item):
            data_dir = item
            break
    print('Retrieving style examples from {} artwork from directory {}'.format(artist.upper(), data_dir))

    save_dir = 'outputs-{}'.format(artist)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    print('Saving weights and outputs to {}'.format(save_dir))

    if train:
        # load tmp weights
        if os.path.exists('tmp') and not resume:
            print('Loading from tmp...')
            assert os.path.exists("tmp/encoder.pt")
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            encoder = torch.load("tmp/encoder.pt", map_location=device)
            decoder = torch.load("tmp/decoder.pt", map_location=device)
            tblock = torch.load("tmp/tblock.pt", map_location=device)
            discrim = torch.load("tmp/discriminator.pt", map_location=device)

        # Losses
        gen_loss = losses.SoftmaxLoss()
        disc_loss = losses.SoftmaxLoss()
        transf_loss = losses.TransformedLoss()
        style_aware_loss = losses.StyleAwareContentLoss()

        if resume:
            print('Resuming training from {}...'.format(save_dir))
            assert os.path.exists(os.path.join(save_dir, 'encoder.pt'))
            lr *= step_lr_gamma
            max_its += 150000
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            encoder = torch.load(save_dir + "/encoder.pt", map_location=device)
            decoder = torch.load(save_dir + "/decoder.pt", map_location=device)
            tblock = torch.load(save_dir + "/tblock.pt", map_location=device)
            discrim = torch.load(save_dir + "/discriminator.pt", map_location=device)

        num_workers = 8
        dataloaders = {
            'train': DataLoader(datasets.PlacesDataset(train=True, input_size=input_size),
                                batch_size=batch_size, shuffle=True, num_workers=num_workers),
            'style': DataLoader(datasets.StyleDataset(data_dir=data_dir, input_size=input_size),
                                batch_size=batch_size, shuffle=True, num_workers=num_workers),
            'test': DataLoader(datasets.TestDataset(), batch_size=1, shuffle=False, num_workers=num_workers)
        }

        # optimizer for encoder/decoder (and tblock? - think it has no parameters though)
        gen_params = []
        for m in [encoder, decoder]:
            for param in m.parameters():
                param.requires_grad = True
                gen_params.append(param)
        g_optimizer = torch.optim.Adam(gen_params, lr=lr)

        # optimizer for discriminator
        disc_params = []
        for param in discrim.parameters():
            param.requires_grad = True
            disc_params.append(param)
        d_optimizer = torch.optim.Adam(disc_params, lr=lr)

        scheduler_g = torch.optim.lr_scheduler.StepLR(g_optimizer, step_lr_step, gamma=step_lr_gamma, last_epoch=-1)
        scheduler_d = torch.optim.lr_scheduler.StepLR(d_optimizer, step_lr_step, gamma=step_lr_gamma, last_epoch=-1)

        its = 300000 if resume else 0
        print('Begin Training:')
        g_steps = 0
        d_steps = 0
        image_id = 0
        time_per_it = []

        if max_its is None:
            max_its = len(dataloaders['train'])

        # set models to train()
        encoder.train()
        decoder.train()
        tblock.train()
        discrim.train()

        d_loss = 0
        g_loss = 0
        gen_acc = 0
        d_acc = 0

        for epoch in range(max_eps):
            if its > max_its:
                break
            for images, style_images in zip(dataloaders['train'], cycle(dataloaders['style'])):
                t0 = process_time()
                # utils.export_image(images[0, :, :, :], style_images[0, :, :, :], 'input_images.jpg')

                # zero gradients
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()

                if its > max_its:
                    break

                if torch.cuda.is_available():
                    images = images.cuda()
                    if style_images is not None:
                        style_images = style_images.cuda()

                # autoencoder
                emb = encoder(images)
                stylized_im = decoder(emb)

                # if training do losses etc
                stylized_emb = encoder(stylized_im)

                if discr_success_rate < win_rate:
                    # discriminator train step
                    # detach from generator, so not propagating unnecessary gradients
                    d_out_fake = discrim(stylized_im.clone().detach())
                    d_out_real_ph = discrim(images)
                    d_out_real_style = discrim(style_images)

                    # accuracy given all the images
                    d_acc_real_ph = utils.accuracy(d_out_real_ph, target_label=0)
                    d_acc_fake_style = utils.accuracy(d_out_fake, target_label=0)
                    d_acc_real_style = utils.accuracy(d_out_real_style, target_label=1)
                    gen_acc = 1 - d_acc_fake_style
                    d_acc = (d_acc_real_ph + d_acc_fake_style + d_acc_real_style) / 3

                    # Loss calculation
                    d_loss = disc_loss(d_out_fake, target_label=0)
                    d_loss += disc_loss(d_out_real_style, target_label=1)
                    d_loss += disc_loss(d_out_real_ph, target_label=0)

                    # Step optimizer
                    d_loss.backward()
                    d_optimizer.step()
                    d_steps += 1

                    # Update success rate
                    discr_success_rate = discr_success_rate * (1. - alpha) + alpha * d_acc
                else:
                    # generator train step
                    # Generator discrim losses
                    d_out_fake = discrim(stylized_im)  # keep attached to generator because grads needed

                    # accuracy given the fake output, generator images
                    gen_acc = utils.accuracy(d_out_fake, target_label=1)  # accuracy given only the output image
                    del g_loss

                    # tblock
                    transformed_inputs, transformed_outputs = tblock(images, stylized_im)

                    g_loss = disc_wt * gen_loss(d_out_fake, target_label=1)
                    g_transf = trans_wt * transf_loss(transformed_inputs, transformed_outputs)
                    g_style = style_wt * style_aware_loss(emb, stylized_emb)
                    # print(g_loss.item(), g_transf.item(), g_style.item())
                    g_loss += g_transf + g_style

                    # Step optimizer
                    g_loss.backward()
                    g_optimizer.step()
                    g_steps += 1

                    # Update success rate
                    discr_success_rate = discr_success_rate * (1. - alpha) + alpha * (1. - gen_acc)

                # report stuff
                t1 = process_time()
                time_per_it.append((t1 - t0) / 3600)
                if len(time_per_it) > 100:
                    time_per_it.pop(0)

                if not its % log_interval:
                    running_mean_it_time = sum(time_per_it) / len(time_per_it)
                    time_rem = (max_its - its + 1) * running_mean_it_time
                    print("{}/{} -- {} G Steps -- G Loss {:.2f} -- G Acc {:.2f} -"
                          "- {} D Steps -- D Loss {:.2f} -- D Acc {:.2f} -"
                          "- {:.2f} D Success -- {:.1f} Hours remaining...".format(
                              its, max_its, g_steps, g_loss, gen_acc,
                              d_steps, d_loss, d_acc, discr_success_rate, time_rem))

                    for idx in range(images.size(0)):
                        output_path = os.path.join(save_dir, 'training_visualise')
                        if not os.path.exists(output_path):
                            os.makedirs(output_path)
                        output_path = os.path.join(output_path, 'iteration_{:06d}_example_{}.jpg'.format(its, idx))
                        utils.export_image([images[idx, :, :, :],
                                            style_images[idx, :, :, :],
                                            stylized_im[idx, :, :, :]], output_path)

                its += 1
                scheduler_g.step()
                scheduler_d.step()

                if not its % 10000:
                    if not os.path.exists(save_dir):
                        os.mkdir(save_dir)
                    torch.save(encoder, save_dir + "/encoder.pt")
                    torch.save(decoder, save_dir + "/decoder.pt")
                    torch.save(tblock, save_dir + "/tblock.pt")
                    torch.save(discrim, save_dir + "/discriminator.pt")

        # only save if running on gpu (otherwise I'm just fixing bugs)
        torch.save(encoder, os.path.join(save_dir, "encoder.pt"))
        torch.save(decoder, os.path.join(save_dir, "decoder.pt"))
        torch.save(tblock, os.path.join(save_dir, "tblock.pt"))
        torch.save(discrim, os.path.join(save_dir, "discriminator.pt"))

        evaluate(encoder, decoder, dataloaders['test'], save_dir=save_dir)
    else:
        print('Loading Models {} and {}'.format(os.path.join(save_dir, "encoder.pt"),
                                                os.path.join(save_dir, "decoder.pt")))
        encoder = torch.load(os.path.join(save_dir, "encoder.pt"), map_location='cpu')
        decoder = torch.load(os.path.join(save_dir, "decoder.pt"), map_location='cpu')
        # encoder.load_state_dict(encoder_dict)
        # decoder.load_state_dict(decoder_dict)

        if torch.cuda.is_available():
            encoder = encoder.cuda()
            decoder = decoder.cuda()

        dataloader = DataLoader(datasets.TestDataset(input_size=input_size), batch_size=1, shuffle=False,
                                num_workers=8)
        evaluate(encoder, decoder, dataloader, save_dir=save_dir)