def main(config):
    """Train the layout-to-image GAN (no attribute conditioning).

    Builds the training loader for ``config.dataset`` ('vg' or 'coco'), a
    ``Generator`` plus an image-level and an object-level discriminator (both
    wrapped with ``add_sn``), loads weights via ``load_model`` using
    ``config.resume_iter``, and then alternates one discriminator update and
    one generator update per iteration until ``config.niter``.

    Every ``config.log_step`` iterations the losses are printed; every
    ``config.tensorboard_step`` iterations (when ``config.use_tensorboard``)
    scalars and image batches are written to TensorBoard; every
    ``config.save_step`` iterations all three networks are checkpointed.

    Raises:
        ValueError: if ``config.dataset`` is neither 'vg' nor 'coco'.
    """
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    if config.dataset == 'vg':
        data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                           VG_DIR=config.vg_dir)
    elif config.dataset == 'coco':
        data_loader, _ = get_dataloader_coco(batch_size=config.batch_size,
                                             COCO_DIR=config.coco_dir)
    else:
        # Fail early instead of hitting a NameError on `data_loader` below.
        raise ValueError('Unsupported dataset: {}'.format(config.dataset))
    vocab_num = data_loader.dataset.num_objects

    assert config.clstm_layers > 0

    netG = Generator(num_embeddings=vocab_num,
                     embedding_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers).to(device)
    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)

    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate,
                                      [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(),
                                            config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(),
                                             config.learning_rate, [0.5, 0.999])

    # The generator's checkpoint determines the iteration to resume from.
    start_iter = load_model(netG, model_dir=model_save_dir, appendix='netG',
                            iter=config.resume_iter)
    _ = load_model(netD_image, model_dir=model_save_dir, appendix='netD_image',
                   iter=config.resume_iter)
    _ = load_model(netD_object, model_dir=model_save_dir,
                   appendix='netD_object', iter=config.resume_iter)

    data_iter = iter(data_loader)

    if start_iter < config.niter:

        if config.use_tensorboard:
            writer = SummaryWriter(log_save_dir)

        for i in range(start_iter, config.niter):
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass (training is iteration-based).
                data_iter = iter(data_loader)
                batch = next(data_iter)

            # ========================================================== #
            #                 1. Preprocess input data                   #
            # ========================================================== #
            imgs, objs, boxes, masks, obj_to_img = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # obj_to_img is left where the loader put it, as before.
            imgs, objs, boxes, masks, obj_to_img, z = \
                imgs.to(device), objs.to(device), boxes.to(device), \
                masks.to(device), obj_to_img, z.to(device)

            # ========================================================== #
            #                 2. Train the discriminator                 #
            # ========================================================== #
            output = netG(imgs, objs, boxes, masks, obj_to_img, z)
            crops_input, crops_input_rec, crops_rand, img_rec, img_rand, \
                mu, logvar, z_rand_rec = output

            # Image adv loss on fakes; detached so D's loss does not reach G.
            out_logits = netD_image(img_rec.detach())
            d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits = netD_image(img_rand.detach())
            d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            d_image_adv_loss_fake = 0.5 * d_image_adv_loss_fake_rec + \
                0.5 * d_image_adv_loss_fake_rand

            # Image adv loss on real images.
            out_logits = netD_image(imgs)
            d_image_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # Object adv loss on fake reconstructed crops.
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            # Object adv loss on fake random-z crops.
            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            d_object_adv_loss_fake = 0.5 * d_object_adv_loss_fake_rec + \
                0.5 * d_object_adv_loss_fake_rand

            # Object adv + classification loss on real crops.
            out_logits_src, out_logits_cls = netD_object(crops_input.detach(),
                                                         objs)
            d_object_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Backward and optimize.
            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake +
                                               d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake +
                                               d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()
            d_loss.backward()
            netD_image_optimizer.step()
            netD_object_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()

            # ========================================================== #
            #                   3. Train the generator                   #
            # ========================================================== #
            # Fresh forward pass so G's losses get their own graph.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z)
            crops_input, crops_input_rec, crops_rand, img_rec, img_rand, \
                mu, logvar, z_rand_rec = output

            # L1 reconstruction losses for the image and the latent code.
            g_img_rec_loss = torch.abs(img_rec - imgs).mean()
            g_z_rec_loss = torch.abs(z_rand_rec - z).mean()

            # KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
            kl_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Image adv loss: G wants D to label fakes as real.
            out_logits = netD_image(img_rec)
            g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            out_logits = netD_image(img_rand)
            g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            g_image_adv_loss_fake = 0.5 * g_image_adv_loss_fake_rec + \
                0.5 * g_image_adv_loss_fake_rand

            # Object adv + classification losses on fake crops.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)

            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)

            g_object_adv_loss = 0.5 * g_object_adv_loss_rec + \
                0.5 * g_object_adv_loss_rand
            g_object_cls_loss = 0.5 * g_object_cls_loss_rec + \
                0.5 * g_object_cls_loss_rand

            # Backward and optimize.
            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()

            # ========================================================== #
            #                           4. Log                           #
            # ========================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, roi_value in loss.items():
                    log += ", {}: {:.4f}".format(tag, roi_value)
                print(log)

            if (i + 1) % config.tensorboard_step == 0 and config.use_tensorboard:
                for tag, roi_value in loss.items():
                    writer.add_scalar(tag, roi_value, i + 1)
                # Each entry is a whole batch, so use add_images (N, C, H, W);
                # add_image expects a single CHW image.
                writer.add_images(
                    'Result/crop_real',
                    imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
                writer.add_images(
                    'Result/crop_real_rec',
                    imagenet_deprocess_batch(crops_input_rec).float() / 255,
                    i + 1)
                writer.add_images(
                    'Result/crop_rand',
                    imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_real',
                    imagenet_deprocess_batch(imgs).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_real_rec',
                    imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_fake_rand',
                    imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:
                save_model(netG, model_dir=model_save_dir, appendix='netG',
                           iter=i + 1, save_num=5, save_step=config.save_step)
                save_model(netD_image, model_dir=model_save_dir,
                           appendix='netD_image', iter=i + 1, save_num=5,
                           save_step=config.save_step)
                save_model(netD_object, model_dir=model_save_dir,
                           appendix='netD_object', iter=i + 1, save_num=5,
                           save_step=config.save_step)

        if config.use_tensorboard:
            writer.close()
def main(config):
    """Train the attribute-conditioned layout-to-image GAN on Visual Genome.

    Besides the image- and object-level discriminators, this trains an
    attribute discriminator ``netD_att`` over ``attribute_nums`` attributes.
    Each iteration:

    1. gives every object that has no annotated attribute the single
       attribute ``netD_att`` scores highest on its real crop
       (``attribute_est``);
    2. for the first third of the images in the batch, replaces the
       attributes of half of their objects with 1-2 attributes sampled
       (weighted by ``matrix``) from those the object does not already have;
    3. updates all discriminators, then the generator.

    NOTE(review): ``pos_weight`` is read below but not defined in this
    function; presumably a module-level tensor defined elsewhere — confirm.
    """
    # Row ``matrix[obj]`` is used as sampling weights over attribute indices.
    matrix = torch.load("matrix_obj_vs_att.pt")

    cudnn.benchmark = True
    # NOTE(review): GPU index is hard-coded.
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)
    attribute_nums = 106

    data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                       attribute_embedding=attribute_nums,
                                       image_size=config.image_size)
    vocab_num = data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)
    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)

    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)
    netD_att = add_sn(netD_att)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate,
                                      [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(),
                                            config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(),
                                             config.learning_rate, [0.5, 0.999])
    netD_att_optimizer = torch.optim.Adam(netD_att.parameters(),
                                          config.learning_rate, [0.5, 0.999])

    # Only the generator's checkpoint decides the iteration to resume from.
    load_model(netD_object, model_dir=model_save_dir, appendix='netD_object',
               iter=config.resume_iter)
    load_model(netD_att, model_dir=model_save_dir, appendix='netD_attribute',
               iter=config.resume_iter)
    load_model(netD_image, model_dir=model_save_dir, appendix='netD_image',
               iter=config.resume_iter)
    start_iter = load_model(netG, model_dir=model_save_dir, appendix='netG',
                            iter=config.resume_iter)

    data_iter = iter(data_loader)

    if start_iter < config.niter:

        if config.use_tensorboard:
            writer = SummaryWriter(log_save_dir)

        for i in range(start_iter, config.niter):
            # ========================================================== #
            #                 1. Preprocess input data                   #
            # ========================================================== #
            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a new pass (training is iteration-based).
                data_iter = iter(data_loader)
                batch = next(data_iter)

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, \
                boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)

            # obj_to_img is left where the loader put it, as before.
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, \
                boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), \
                masks.to(device), obj_to_img, z.to(device), \
                attribute.to(device), masks_shift.to(device), \
                boxes_shift.to(device)

            # Untouched copy of the annotated attributes.
            attribute_GT = attribute.clone()

            # Estimate attributes for objects with no annotation: set the
            # argmax of netD_att on the real crop.  (The previous
            # `~att_mask.byte()` is a bitwise NOT on uint8 in PyTorch >= 1.2,
            # i.e. 255/254 instead of 1/0, which broke the indices.)
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            with torch.no_grad():  # only the argmax is used
                estimated_att = netD_att(crops_input)
            unannotated = (attribute.sum(dim=1) == 0).nonzero().view(-1)
            attribute_est[unannotated, estimated_att.argmax(1)[unannotated]] = 1

            # Change GT attributes for half the objects of the first third of
            # the images in the batch.
            num_img_to_change = math.floor(imgs.shape[0] / 3)
            for img_idx in range(num_img_to_change):
                obj_indices = torch.nonzero(obj_to_img == img_idx).view(-1)
                num_objs_to_change = math.floor(len(obj_indices) / 2)
                for obj_idx in obj_indices[:num_objs_to_change]:
                    obj = objs[obj_idx]
                    old_attributes = torch.nonzero(
                        attribute_GT[obj_idx]).view(-1)
                    # Zero the weights of attributes the object already has.
                    weights = matrix[obj].scatter(0, old_attributes.cpu(), 0)
                    new_attribute = random.choices(range(attribute_nums),
                                                   weights,
                                                   k=random.randrange(1, 3))
                    new_attribute = torch.LongTensor(new_attribute).to(device)
                    # Replace all attributes of the object with the sampled ones.
                    attribute[obj_idx] = 0
                    attribute[obj_idx] = attribute[obj_idx].scatter(
                        0, new_attribute, 1)
                    # Estimated attributes follow the changed ones.
                    attribute_est[obj_idx] = attribute[obj_idx]

            # ========================================================== #
            #                 2. Train the discriminator                 #
            # ========================================================== #
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, \
                img_rand, img_shift, mu, logvar, z_rand_rec, \
                z_rand_shift = output

            # Image adv loss on fakes (rec/rand 0.4 each, shifted 0.2).
            out_logits = netD_image(img_rec.detach())
            d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits = netD_image(img_rand.detach())
            d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits = netD_image(img_shift.detach())
            d_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            d_image_adv_loss_fake = 0.4 * d_image_adv_loss_fake_rec + \
                0.4 * d_image_adv_loss_fake_rand + \
                0.2 * d_image_adv_loss_fake_shift

            # Image adv loss on real images.
            out_logits = netD_image(imgs)
            d_image_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # Object adv loss on fake crops (same 0.4/0.4/0.2 weighting).
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits, _ = netD_object(crops_shift.detach(), objs)
            d_object_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            d_object_adv_loss_fake = 0.4 * d_object_adv_loss_fake_rec + \
                0.4 * d_object_adv_loss_fake_rand + \
                0.2 * d_object_adv_loss_fake_shift

            # Object adv + classification loss on real crops.
            out_logits_src, out_logits_cls = netD_object(crops_input.detach(),
                                                         objs)
            d_object_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Attribute loss on real crops, only for annotated objects.
            # view(-1) keeps a 1-D index even when only one object qualifies.
            att_cls = netD_att(crops_input.detach())
            att_idx = attribute_GT.sum(dim=1).nonzero().view(-1)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            attribute_annotated = torch.index_select(attribute_GT, 0, att_idx)
            d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Backward and optimize.
            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake +
                                               d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake +
                                               d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real
            d_loss += config.lambda_att_cls * d_object_att_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()
            netD_att.zero_grad()
            d_loss.backward()
            netD_image_optimizer.step()
            netD_object_optimizer.step()
            netD_att_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()
            loss['D/object_att_cls_loss'] = d_object_att_cls_loss_real.item()

            # ========================================================== #
            #                   3. Train the generator                   #
            # ========================================================== #
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, \
                img_rand, img_shift, mu, logvar, z_rand_rec, \
                z_rand_shift = output

            # Image reconstruction loss.  The first num_img_to_change images
            # had attributes changed above, so they are masked out.
            rec_img_mask = torch.ones(imgs.shape[0], device=device)
            rec_img_mask[:num_img_to_change] = 0
            g_img_rec_loss = rec_img_mask * torch.abs(img_rec - imgs).view(
                imgs.shape[0], -1).mean(1)
            g_img_rec_loss = g_img_rec_loss.sum() / (imgs.shape[0] -
                                                     num_img_to_change)

            g_z_rec_loss_rand = torch.abs(z_rand_rec - z).mean()
            g_z_rec_loss_shift = torch.abs(z_rand_shift - z).mean()
            g_z_rec_loss = 0.5 * g_z_rec_loss_rand + 0.5 * g_z_rec_loss_shift

            # KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
            kl_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Image adv loss: G wants D to label fakes as real.
            out_logits = netD_image(img_rec)
            g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            out_logits = netD_image(img_rand)
            g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            out_logits = netD_image(img_shift)
            g_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            g_image_adv_loss_fake = 0.4 * g_image_adv_loss_fake_rec + \
                0.4 * g_image_adv_loss_fake_rand + \
                0.2 * g_image_adv_loss_fake_shift

            # Attribute targets: the (possibly changed) attributes of
            # annotated objects.
            att_idx = attribute.sum(dim=1).nonzero().view(-1)
            attribute_annotated = torch.index_select(attribute, 0, att_idx)

            # Reconstructed crops.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)
            att_cls = netD_att(crops_input_rec)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rec = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Random-z crops.
            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)
            att_cls = netD_att(crops_rand)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rand = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Shifted-layout crops.
            out_logits_src, out_logits_cls = netD_object(crops_shift, objs)
            g_object_adv_loss_shift = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_shift = F.cross_entropy(out_logits_cls, objs)
            att_cls = netD_att(crops_shift)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_shift = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            g_object_att_cls_loss = 0.4 * g_object_att_cls_loss_rec + \
                0.4 * g_object_att_cls_loss_rand + \
                0.2 * g_object_att_cls_loss_shift
            g_object_adv_loss = 0.4 * g_object_adv_loss_rec + \
                0.4 * g_object_adv_loss_rand + 0.2 * g_object_adv_loss_shift
            g_object_cls_loss = 0.4 * g_object_cls_loss_rec + \
                0.4 * g_object_cls_loss_rand + 0.2 * g_object_cls_loss_shift

            # Backward and optimize.
            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_att_cls * g_object_att_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()
            loss['G/object_att_cls_loss'] = g_object_att_cls_loss.item()

            # ========================================================== #
            #                           4. Log                           #
            # ========================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, roi_value in loss.items():
                    log += ", {}: {:.4f}".format(tag, roi_value)
                print(log)

            if (i + 1) % config.tensorboard_step == 0 and config.use_tensorboard:
                for tag, roi_value in loss.items():
                    writer.add_scalar(tag, roi_value, i + 1)
                writer.add_images(
                    'Result/crop_real',
                    imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
                writer.add_images(
                    'Result/crop_real_rec',
                    imagenet_deprocess_batch(crops_input_rec).float() / 255,
                    i + 1)
                writer.add_images(
                    'Result/crop_rand',
                    imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_real',
                    imagenet_deprocess_batch(imgs).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_real_rec',
                    imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_fake_rand',
                    imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:
                save_model(netG, model_dir=model_save_dir, appendix='netG',
                           iter=i + 1, save_num=2, save_step=config.save_step)
                save_model(netD_image, model_dir=model_save_dir,
                           appendix='netD_image', iter=i + 1, save_num=2,
                           save_step=config.save_step)
                save_model(netD_object, model_dir=model_save_dir,
                           appendix='netD_object', iter=i + 1, save_num=2,
                           save_step=config.save_step)
                save_model(netD_att, model_dir=model_save_dir,
                           appendix='netD_attribute', iter=i + 1, save_num=2,
                           save_step=config.save_step)

        if config.use_tensorboard:
            writer.close()
def main(config):
    """Evaluate attribute prediction and colour-attribute editing on VG val.

    For every validation batch this:

    * generates images with a trained ``Generator``.  Objects without
      annotated attributes get the argmax attribute of ``netD_att``;
    * records, for annotated objects, which attributes ``netD_att`` predicts
      on the random-z crops (sigmoid > 0.9) alongside the ground truth;
    * recolours every object to attribute ``tgt`` (all listed colour
      attributes cleared first) and regenerates the images;
    * writes real/rec/rand/shift PNGs, plus ``*_modified`` PNGs for images
      containing at least one object whose edit ``netD_att`` judged
      successful.

    Finally it prints per-object precision/recall statistics of the
    attribute predictions.
    """
    cudnn.benchmark = True
    # NOTE(review): GPU index is hard-coded.
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    with open("data/vocab.json", 'r') as f:
        vocab = json.load(f)
    att_idx_to_name = np.array(vocab['attribute_idx_to_name'])
    print(att_idx_to_name)
    object_idx_to_name = np.array(vocab['object_idx_to_name'])

    attribute_nums = 106
    train_data_loader, val_data_loader = get_dataloader_vg(
        batch_size=config.batch_size, attribute_embedding=attribute_nums)
    vocab_num = train_data_loader.dataset.num_objects

    netG = Generator(num_embeddings=vocab_num,
                     obj_att_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers,
                     obj_size=config.object_size,
                     attribute_dim=attribute_nums).to(device)
    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)
    netD_att = add_sn(netD_att)

    # '~' is not expanded automatically by the os/path APIs.
    load_model(netD_att,
               model_dir=os.path.expanduser("~/models/trained_models"),
               appendix='netD_attribute', iter=config.resume_iter)
    _ = load_model(netG, model_dir=model_save_dir, appendix='netG',
                   iter=config.resume_iter)

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    # Colour attribute indices cleared before applying the target colour.
    color_attributes = [2, 8, 0, 94, 90, 95, 96, 34, 25, 70, 58, 104]
    tgt = 95  # target attribute index (original note: 95 black)

    # Per-batch prediction / ground-truth rows for annotated objects; replaces a
    # fixed-size buffer that would overflow on larger splits.
    pred_chunks = []
    gt_chunks = []

    with torch.no_grad():
        netG.eval()
        # Evaluation: keep the attribute classifier out of train mode so its
        # forward passes behave deterministically.
        netD_att.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, \
                boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)

            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, \
                boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), \
                masks.to(device), obj_to_img, z.to(device), \
                attribute.to(device), masks_shift.to(device), \
                boxes_shift.to(device)

            # Indices of objects with at least one annotated attribute.
            # view(-1) keeps the index 1-D even when only one object qualifies.
            att_idx = attribute.sum(dim=1).nonzero().view(-1)

            # Estimate attributes for unannotated objects: argmax of netD_att.
            # (The previous `~att_mask.byte()` is bitwise NOT on uint8 in
            # PyTorch >= 1.2 and produced out-of-range indices.)
            attribute_est = attribute.clone()
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            estimated_att = netD_att(crops_input)
            unannotated = (attribute.sum(dim=1) == 0).nonzero().view(-1)
            attribute_est[unannotated, estimated_att.argmax(1)[unannotated]] = 1

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, \
                img_rand, img_shift, mu, logvar, z_rand_rec, \
                z_rand_shift = output

            # Predict attributes on generated crops of annotated objects.
            if att_idx.numel() > 0:
                crops_rand_yes = torch.index_select(crops_rand, 0, att_idx)
                attribute_yes = torch.index_select(attribute, 0, att_idx)
                att_cls = torch.sigmoid(netD_att(crops_rand_yes))
                pred_chunks.append(
                    (att_cls > 0.9).cpu().numpy().astype(np.float64))
                gt_chunks.append(attribute_yes.cpu().numpy())

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)
            img_rec = imagenet_deprocess_batch(img_rec)

            # Attribute modification: recolour every object to `tgt`.
            attribute[:, color_attributes] = 0
            attribute[:, tgt] = 1
            attribute_est[:, color_attributes] = 0
            attribute_est[:, tgt] = 1
            changed_list = list(range(objs.size(0)))

            # Regenerate with the modified attributes.
            z = torch.randn(objs.size(0), config.z_dim).to(device)
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand_y, crops_shift_y, \
                img_rec_y, img_rand_y, img_shift_y, mu, logvar, z_rand_rec, \
                z_rand_shift = output

            img_rand_y = imagenet_deprocess_batch(img_rand_y)
            img_shift_y = imagenet_deprocess_batch(img_shift_y)
            img_rec_y = imagenet_deprocess_batch(img_rec_y)
            imgs = imagenet_deprocess_batch(imgs)

            # Count an edit only if `tgt` was not already in the top-5 before
            # the edit and is in the top-3 afterwards.
            max_idx = netD_att(crops_rand).topk(5)[1]
            changed_list = [c for c in changed_list if tgt not in max_idx[c]]
            max_idx_y = netD_att(crops_rand_y).topk(3)[1]
            changed_list_success = [
                c for c in changed_list if tgt in max_idx_y[c]
            ]

            for j in range(imgs.shape[0]):
                img_id = i * config.batch_size + j

                img_np = img_shift[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir, 'img{:06d}_shift.png'.format(img_id))
                imwrite(img_path, img_np)

                img_np = img_rand[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir, 'img{:06d}_rand.png'.format(img_id))
                imwrite(img_path, img_np)

                img_rec_np = img_rec[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir, 'img{:06d}_rec.png'.format(img_id))
                imwrite(img_path, img_rec_np)

                img_real_np = imgs[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir, 'img{:06d}_real.png'.format(img_id))
                imwrite(img_path, img_real_np)

                cur_obj_success = [
                    int(objs[c]) for c in changed_list_success
                    if obj_to_img[c] == j
                ]
                # Save successfully modified images.
                if len(cur_obj_success) > 0:
                    # Separate local name: the old code rebound img_shift_y
                    # itself here, corrupting every later image in the batch.
                    img_shift_y_np = img_shift_y[j].numpy().transpose(1, 2, 0)
                    img_path_red = os.path.join(
                        result_save_dir,
                        'img{:06d}_shift_{}_modified.png'.format(
                            img_id, object_idx_to_name[cur_obj_success]))
                    imwrite(img_path_red, img_shift_y_np)

                    img_rec_np = img_rec_y[j].numpy().transpose(1, 2, 0)
                    img_path = os.path.join(
                        result_save_dir,
                        'img{:06d}_rec_modified.png'.format(img_id))
                    imwrite(img_path, img_rec_np)

                    img_rand_np = img_rand_y[j].numpy().transpose(1, 2, 0)
                    img_path = os.path.join(
                        result_save_dir,
                        'img{:06d}_rand_modified.png'.format(img_id))
                    imwrite(img_path, img_rand_np)

    if not gt_chunks:
        print("no annotated objects found; nothing to evaluate")
        return

    attributes_gt = np.vstack(gt_chunks)
    attributes_pred = np.vstack(pred_chunks)

    # calculate recall precision; count rows are (tn, fp, fn, tp)
    num_data = attributes_gt.shape[0]
    count = np.zeros((num_data, 4))
    recall, precision = np.zeros(num_data), np.zeros(num_data)
    for n in range(num_data):
        count[n] = confusion_matrix(attributes_gt[n], attributes_pred[n]).ravel()
        recall[n] = count[n][3] / (count[n][3] + count[n][2])
        precision[n] = safe_division(count[n][3], count[n][3] + count[n][1])

    print("average precision = {}".format(precision.mean()))
    print("average recall = {}".format(recall.mean()))
    print("average pred # per obj")
    print((count[:, 1].sum() + count[:, 3].sum()) / count.shape[0])
    print("average GT # per obj")
    print((count[:, 2].sum() + count[:, 3].sum()) / count.shape[0])
    print("% of data that predict something")
    print(((count[:, 1] + count[:, 3]) > 0).sum() / count.shape[0])
    print("% of data at least predicted correct once")
    print((count[:, 3] > 0).sum() / count.shape[0])
def main(config):
    """Generate one random-z image per validation layout with a saved generator.

    Loads the generator state dict from ``config.saved_model`` and, for every
    validation batch of ``config.dataset`` ('vg', 'coco' or 'shapenet_<id>'),
    writes the deprocessed ``img_rand`` outputs to
    ``config.results_dir/img{index:06d}.png``.

    Raises:
        ValueError: if ``config.dataset`` is not one of the supported values.
    """
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    result_save_dir = config.results_dir
    Path(result_save_dir).mkdir(parents=True, exist_ok=True)

    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, VG_DIR=config.vg_dir)
        deprocess_batch = imagenet_deprocess_batch
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size, COCO_DIR=config.coco_dir)
        deprocess_batch = imagenet_deprocess_batch
    elif config.dataset.startswith('shapenet'):
        dset_id = config.dataset.split('_')[-1]
        train_data_loader, val_data_loader = get_dataloader_shapenet(
            batch_size=config.batch_size, dset_id=dset_id,
            DATA_DIR=config.shapenet_dir, testset=True)
        # Bind the dataset we actually iterate instead of late-binding a name
        # that is only assigned further down.
        shapenet_dset = val_data_loader.dataset

        def deprocess_batch(x):
            return shapenet_deprocess_batch(x, dset=shapenet_dset)
    else:
        raise ValueError('Unsupported dataset: {}'.format(config.dataset))

    vocab_num = train_data_loader.dataset.num_objects

    assert config.clstm_layers > 0

    netG = Generator(num_embeddings=vocab_num,
                     embedding_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers).to(device)

    print('load model from: {}'.format(config.saved_model))
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    netG.load_state_dict(torch.load(config.saved_model, map_location=device))

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img = batch
            z = torch.randn(objs.size(0), config.z_dim)
            imgs, objs, boxes, masks, obj_to_img, z = \
                imgs.to(device), objs.to(device), boxes.to(device), \
                masks.to(device), obj_to_img, z.to(device)

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z)
            crops_input, crops_input_rec, crops_rand, img_rec, img_rand, \
                mu, logvar, z_rand_rec = output

            img_rand = deprocess_batch(img_rand)

            # Save the generated images
            for j in range(img_rand.shape[0]):
                img_np = img_rand[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)
def main(config):
    """Run the attribute-conditioned generator on the validation set and dump results.

    For every validation batch, objects whose attribute row is all zeros get
    one attribute filled in from the attribute discriminator ``netD_att``. The
    generator is then run, and:

    * the tensors listed in ``dict_to_save`` are pickled to
      ``~/pickle/128_vg_pkl_resinet50_247k/batch_<i>.pkl``;
    * ``img_rand`` and ``img_shift`` are deprocessed and written to
      ``result_save_dir`` as ``img{:06d}.png`` and ``img{:06d}s.png``.

    Args:
        config: parsed options. Read here: ``exp_name``, ``dataset``
            (``'vg'`` or ``'coco'``), ``batch_size``, ``clstm_layers``,
            ``embedding_dim``, ``z_dim``, ``object_size`` and ``resume_iter``.

    Raises:
        ValueError: if ``config.dataset`` is neither ``'vg'`` nor ``'coco'``.
    """
    cudnn.benchmark = True
    device = torch.device('cuda:0')
    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)
    # os.path / os.mkdir do not expand '~' themselves; expand it explicitly
    # and create missing parent directories.
    resinet_dir = os.path.expanduser("~/pickle/128_vg_pkl_resinet50_247k")
    os.makedirs(resinet_dir, exist_ok=True)

    attribute_nums = 106
    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, attribute_embedding=attribute_nums)
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size)
    else:
        # train_data_loader is only assigned in the branches above.
        raise ValueError('unknown dataset: {}'.format(config.dataset))

    vocab_num = train_data_loader.dataset.num_objects
    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)
    _ = load_model(netG, model_dir=model_save_dir, appendix='netG',
                   iter=config.resume_iter)

    netD_att = train_att_change.AttributeDiscriminator(
        n_attribute=attribute_nums).to(device)
    netD_att = train_att_change.add_sn(netD_att)
    _ = load_model(netD_att,
                   model_dir=os.path.expanduser("~/models/trained_models"),
                   appendix='netD_attribute', iter=config.resume_iter)

    data_loader = val_data_loader
    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_loader):
            print('batch {}'.format(i))
            (imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift,
             boxes_shift) = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # obj_to_img is passed on as the loader returns it.
            imgs, objs, boxes, masks = (imgs.to(device), objs.to(device),
                                        boxes.to(device), masks.to(device))
            z, attribute = z.to(device), attribute.to(device)
            masks_shift = masks_shift.to(device)
            boxes_shift = boxes_shift.to(device)

            # Estimate attributes: every row of `attribute` that sums to zero
            # gets a 1 at netD_att's argmax for that object's crop. A bool
            # mask is used because `~` on uint8 tensors is bitwise NOT
            # (0 -> 255) in PyTorch >= 1.2, and nonzero().view(-1) stays 1-D
            # even when exactly one object qualifies.
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            estimated_att = netD_att(crops_input)
            attribute_est = attribute.clone()
            missing = torch.nonzero(attribute.sum(dim=1) == 0).view(-1)
            attribute_est[missing, estimated_att.argmax(1)[missing]] = 1

            # Generate fake images
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            dict_to_save = {
                'imgs': imgs,
                'imgs_rand': img_rand,
                'imgs_shift': img_shift,
                'objs': objs,
                'boxes': boxes,
                'boxes_shift': boxes_shift,
                'obj_to_img': obj_to_img
            }
            out_name = os.path.join(resinet_dir, 'batch_{}.pkl'.format(i))
            with open(out_name, 'wb') as f:
                pickle.dump(dict_to_save, f)

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)

            # Save the generated images (CHW tensor -> HWC array); the shifted
            # layout version gets an 's' suffix.
            for j in range(img_rand.shape[0]):
                img_index = i * config.batch_size + j
                img_np = img_rand[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir, 'img{:06d}.png'.format(img_index))
                imwrite(img_path, img_np)
                img_np = img_shift[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir, 'img{:06d}s.png'.format(img_index))
                imwrite(img_path, img_np)
niter = 400000 attribute_nums = 106 batch_size = 12 save_step = 2000 netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device) netD_att = add_sn(netD_att) start_iter = load_model(netD_att, model_dir=model_save_dir, appendix='netD_attribute', iter='l') netD_att_optimizer = torch.optim.Adam(netD_att.parameters(), 2e-4, [0.5, 0.999]) dataset = 'vg' if dataset == 'vg': train_data_loader, val_data_loader = get_dataloader_vg( batch_size=batch_size, attribute_embedding=attribute_nums) elif dataset == 'coco': train_data_loader, val_data_loader = get_dataloader_coco( batch_size=batch_size) data_loader = train_data_loader data_iter = iter(data_loader) if start_iter < niter: writer = SummaryWriter(log_save_dir) for i in range(start_iter, niter): try: batch = next(data_iter)
def main(config):
    """Measure how layout-shifted generations differ from unshifted ones.

    For every validation batch the generator produces ``img_rand`` and
    ``img_shift``. After deprocessing, three L1 distances per pixel value are
    accumulated and printed per batch and averaged over all batches:

    * background: ``|img_rand - img_shift|`` outside the union of the
      foreground objects' masks (``masks`` and ``masks_shift``);
    * foreground: ``|img_rand - img_shift|`` compared pixel-by-pixel inside
      each foreground object's mask and its shifted mask;
    * rand: ``|img_rand[j] - img_shift[j - 1]|``, i.e. against another image
      of the same batch.

    An object counts as foreground when ``box[2] - box[0] < 0.5`` (presumably
    a normalized box width -- confirm against the data loader). Each batch is
    also pickled to ``~/pickle/vg_pkl_resinet50`` and the background-only
    images are written to ``result_save_dir1`` / ``result_save_dir2``.

    Args:
        config: parsed options. Read here: ``exp_name``, ``dataset``
            (``'vg'`` or ``'coco'``), ``batch_size``, ``clstm_layers``,
            ``embedding_dim``, ``z_dim``, ``object_size`` and ``resume_iter``.

    Raises:
        ValueError: if ``config.dataset`` is neither ``'vg'`` nor ``'coco'``.
    """
    cudnn.benchmark = True
    device = torch.device('cuda:0')
    (log_save_dir, model_save_dir, sample_save_dir, result_save_dir1,
     result_save_dir2) = prepare_dir(config.exp_name)
    # os.path / os.mkdir do not expand '~' themselves; expand it explicitly
    # and create missing parent directories.
    resinet_dir = os.path.expanduser("~/pickle/vg_pkl_resinet50")
    os.makedirs(resinet_dir, exist_ok=True)

    attribute_nums = 106
    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, attribute_embedding=attribute_nums)
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size)
    else:
        # train_data_loader is only assigned in the branches above.
        raise ValueError('unknown dataset: {}'.format(config.dataset))

    vocab_num = train_data_loader.dataset.num_objects
    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)
    _ = load_model(netG, model_dir=model_save_dir, appendix='netG',
                   iter=config.resume_iter)

    netD_att = train_att_change.AttributeDiscriminator(
        n_attribute=attribute_nums).to(device)
    netD_att = train_att_change.add_sn(netD_att)
    _ = load_model(netD_att,
                   model_dir=os.path.expanduser("~/models/trained_models"),
                   appendix='netD_attribute', iter=config.resume_iter)

    data_loader = val_data_loader
    L1_dist_b = 0
    L1_dist_f = 0
    L1_rand_dist = 0
    num_batches = 0
    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_loader):
            print('batch {}'.format(i))
            (imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift,
             boxes_shift) = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # obj_to_img stays where the loader put it; it is used below to
            # select objects per image.
            imgs, objs, boxes, masks = (imgs.to(device), objs.to(device),
                                        boxes.to(device), masks.to(device))
            z, attribute = z.to(device), attribute.to(device)
            masks_shift = masks_shift.to(device)
            boxes_shift = boxes_shift.to(device)

            # Estimate attributes: every row of `attribute` that sums to zero
            # gets a 1 at netD_att's argmax for that object's crop. A bool
            # mask is used because `~` on uint8 tensors is bitwise NOT
            # (0 -> 255) in PyTorch >= 1.2, and nonzero().view(-1) stays 1-D
            # even when exactly one object qualifies.
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            estimated_att = netD_att(crops_input)
            attribute_est = attribute.clone()
            missing = torch.nonzero(attribute.sum(dim=1) == 0).view(-1)
            attribute_est[missing, estimated_att.argmax(1)[missing]] = 1

            # Generate fake images
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # One 0-d bool tensor per object (kept as a list for the pickle).
            foreground = [(box.data[2] - box.data[0]) < 0.5 for box in boxes]
            dict_to_save = {
                'imgs': imgs,
                'imgs_rand': img_rand,
                'imgs_shift': img_shift,
                'objs': objs,
                'boxes': boxes,
                'boxes_shift': boxes_shift,
                'obj_to_img': obj_to_img,
                'foreground': foreground
            }
            out_name = os.path.join(resinet_dir, 'batch_{}.pkl'.format(i))
            with open(out_name, 'wb') as f:
                pickle.dump(dict_to_save, f)

            # int32 so that pixel differences can go negative before abs().
            img_rand = imagenet_deprocess_batch(img_rand).type(torch.int32)
            img_shift = imagenet_deprocess_batch(img_shift).type(torch.int32)
            num_imgs = img_rand.shape[0]
            _, height, width = img_rand.shape[1:]

            fore_count = 0
            background_dist = 0
            foreground_dist = 0
            rand_diff = 0
            for j in range(num_imgs):
                img_index = i * config.batch_size + j
                obj_indices = torch.nonzero(obj_to_img == j).view(-1).tolist()
                fg_indices = [k for k in obj_indices if foreground[k]]

                if fg_indices:
                    union = torch.max(torch.max(masks[fg_indices], 0)[0],
                                      torch.max(masks_shift[fg_indices], 0)[0])
                    # .byte() truncates fractional mask values to 0 before
                    # the bool conversion.
                    foreground_mask = union.byte().bool().cpu()
                else:
                    print("no foreground")
                    foreground_mask = torch.zeros(1, height, width,
                                                  dtype=torch.bool)
                background = (~foreground_mask).int()  # 1 outside foreground

                background_dist += safe_division(
                    torch.abs((img_rand[j] - img_shift[j]) * background).sum(),
                    3 * background.float().sum())

                # Save the background-only images (CHW -> HWC).
                img_np = (img_rand[j] * background).numpy().transpose(1, 2, 0)
                imwrite(os.path.join(result_save_dir1,
                                     'img{:06d}.png'.format(img_index)), img_np)
                img_np = (img_shift[j] * background).numpy().transpose(1, 2, 0)
                imwrite(os.path.join(result_save_dir2,
                                     'img{:06d}.png'.format(img_index)), img_np)

                for k in fg_indices:
                    mask = masks[k].byte().bool().cpu()
                    mask_shift = masks_shift[k].byte().bool().cpu()
                    try:
                        diff = (torch.masked_select(img_rand[j], mask) -
                                torch.masked_select(img_shift[j], mask_shift))
                    except RuntimeError:
                        # The two masks select different numbers of pixels,
                        # so they cannot be compared element-wise; skip it.
                        continue
                    foreground_dist += safe_division(
                        torch.abs(diff).sum(), 3 * mask.float().sum())
                    fore_count += 1

                # j - 1 wraps to the last image for j == 0.
                rand_diff += (torch.abs(img_rand[j] - img_shift[j - 1]).sum() /
                              img_rand[j].numel())

            # Average over the images actually in this batch (the final
            # batch can be smaller than config.batch_size).
            batch_b = float(background_dist / num_imgs)
            batch_f = float(safe_division(foreground_dist, fore_count))
            batch_rand = float(rand_diff / num_imgs)
            L1_dist_b += batch_b
            L1_dist_f += batch_f
            L1_rand_dist += batch_rand
            num_batches += 1
            print("runing b_dist = %f, f_dist = %f, rand_diff = %f" %
                  (batch_b, batch_f, batch_rand))

    if num_batches == 0:
        print("no validation batches")
        return
    print("Background L1 dist = %f" % (L1_dist_b / num_batches))
    print("Foreground L1 dist = %f" % (L1_dist_f / num_batches))
    print("Rand L1 dist = %f" % (L1_rand_dist / num_batches))