def forward(self, imgs, objs, boxes, masks, obj_to_img, z_rand, attribute,
            masks_shift, boxes_shift, attribute_est):
    crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, self.obj_size)
    z_rec, mu, logvar = self.crop_encoder(crops_input, objs)

    objs_att = self.attribute_encoder(objs, attribute)
    objs_att_est = self.attribute_encoder(objs, attribute_est)

    # (n, clstm_dim*2, 8, 8)
    h_rec = self.layout_encoder(objs_att_est, masks, obj_to_img, z_rec, objs)
    h_rand = self.layout_encoder(objs_att, masks, obj_to_img, z_rand, objs)
    h_shift = self.layout_encoder(objs_att, masks_shift, obj_to_img, z_rand, objs)

    # global context encoder
    h_rec_global = self.global_encoder(h_rec)
    h_rand_global = self.global_encoder(h_rand)
    h_shift_global = self.global_encoder(h_shift)

    img_rec = self.decoder(h_rec, h_rec_global)
    img_rand = self.decoder(h_rand, h_rand_global)
    img_shift = self.decoder(h_shift, h_shift_global)

    crops_rand = crop_bbox_batch(img_rand, boxes, obj_to_img, self.obj_size)
    _, z_rand_rec, _ = self.crop_encoder(crops_rand, objs)
    crops_input_rec = crop_bbox_batch(img_rec, boxes, obj_to_img, self.obj_size)
    crops_shift = crop_bbox_batch(img_shift, boxes_shift, obj_to_img, self.obj_size)
    _, z_rand_shift, _ = self.crop_encoder(crops_shift, objs)

    return (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
            img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift)
def forward(self, imgs, objs, boxes, masks, obj_to_img, z_rand):
    crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, self.obj_size)
    z_rec, mu, logvar = self.crop_encoder(crops_input, objs)

    # (n, clstm_dim*2, 8, 8)
    h_rec = self.layout_encoder(objs, masks, obj_to_img, z_rec)
    h_rand = self.layout_encoder(objs, masks, obj_to_img, z_rand)

    img_rec = self.decoder(h_rec)
    img_rand = self.decoder(h_rand)

    crops_rand = crop_bbox_batch(img_rand, boxes, obj_to_img, self.obj_size)
    _, z_rand_rec, _ = self.crop_encoder(crops_rand, objs)
    crops_input_rec = crop_bbox_batch(img_rec, boxes, obj_to_img, self.obj_size)

    return (crops_input, crops_input_rec, crops_rand, img_rec, img_rand, mu,
            logvar, z_rand_rec)
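# --------------------------------------------------------------------------- #
# A minimal smoke-test sketch of the inputs the two forward passes above
# expect. The shapes are assumptions inferred from how the tensors are used in
# this repo (per-object boxes in [0, 1], per-object masks, and an obj_to_img
# index mapping each object to its image); the sizes are illustrative, not
# values confirmed by this file.
# --------------------------------------------------------------------------- #
def _make_dummy_inputs(num_imgs=2, num_objs=6, z_dim=64, attr_dim=106, size=64):
    imgs = torch.randn(num_imgs, 3, size, size)                   # RGB images
    objs = torch.randint(1, 100, (num_objs,))                     # object class ids
    boxes = torch.tensor([[0.1, 0.1, 0.6, 0.6]]).repeat(num_objs, 1)  # (x0, y0, x1, y1)
    masks = torch.ones(num_objs, 1, size, size)                   # per-object masks
    obj_to_img = torch.arange(num_imgs).repeat_interleave(num_objs // num_imgs)
    z_rand = torch.randn(num_objs, z_dim)                         # one latent per object
    attribute = torch.zeros(num_objs, attr_dim)                   # multi-hot attributes
    return imgs, objs, boxes, masks, obj_to_img, z_rand, attribute
# With a constructed Generator, the attribute-aware variant would then be called as
# netG(imgs, objs, boxes, masks, obj_to_img, z_rand, attribute,
#      masks_shift, boxes_shift, attribute_est).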
def test_model(model, dataloaders, criterion, input_size=224):
    phase = 'val'
    model.eval()

    running_loss = 0.0
    running_corrects = 0
    running_count = 0

    # Iterate over data.
    for i, batch in enumerate(dataloaders[phase]):
        imgs, objs, boxes, masks, obj_to_img, attributes = batch
        inputs = crop_bbox_batch(imgs, boxes, obj_to_img, input_size)
        inputs = inputs.to(device)
        labels = objs.to(device)

        # forward; no gradient tracking during evaluation
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

        # statistics
        running_loss += loss.item() * labels.size(0)
        running_corrects += torch.sum(preds == labels.data)
        running_count += labels.size(0)

        if (i + 1) % 20 == 0:
            print('loss: {:.4f} accu: {:.4f}'.format(
                loss.item(), torch.mean((preds == labels.data).float())))

    epoch_loss = running_loss / running_count
    epoch_acc = running_corrects.double() / running_count
    print('================================================================')
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
def main(config):
    matrix = torch.load("matrix_obj_vs_att.pt")  # object/attribute co-occurrence weights
    cudnn.benchmark = True
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    attribute_nums = 106
    data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                       attribute_embedding=attribute_nums,
                                       image_size=config.image_size)
    vocab_num = data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)
    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)

    # spectral normalization
    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)
    netD_att = add_sn(netD_att)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate, [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(), config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(), config.learning_rate, [0.5, 0.999])
    netD_att_optimizer = torch.optim.Adam(netD_att.parameters(), config.learning_rate, [0.5, 0.999])

    _ = load_model(netD_object, model_dir=model_save_dir, appendix='netD_object', iter=config.resume_iter)
    _ = load_model(netD_att, model_dir=model_save_dir, appendix='netD_attribute', iter=config.resume_iter)
    _ = load_model(netD_image, model_dir=model_save_dir, appendix='netD_image', iter=config.resume_iter)
    start_iter = load_model(netG, model_dir=model_save_dir, appendix='netG', iter=config.resume_iter)

    data_iter = iter(data_loader)
    if start_iter < config.niter:
        if config.use_tensorboard:
            writer = SummaryWriter(log_save_dir)

        for i in range(start_iter, config.niter):
            # =========================================================================== #
            # 1. Preprocess input data                                                    #
            # =========================================================================== #
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                batch = next(data_iter)

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # indices of objects that carry at least one annotated attribute
            att_idx = attribute.sum(dim=1).nonzero().squeeze()

            # =========================================================================== #
            # 2. Train the discriminator                                                  #
            # =========================================================================== #
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), obj_to_img, \
                z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)
            attribute_GT = attribute.clone()

            # Estimate attributes for objects without annotations: take the
            # attribute discriminator's top prediction for each such object.
            attribute_est = attribute.clone()
            att_mask = torch.zeros(attribute.shape[0])
            att_mask = att_mask.scatter(0, att_idx, 1).to(device)
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, config.object_size)
            estimated_att = netD_att(crops_input)
            max_idx = estimated_att.argmax(1)
            max_idx = max_idx.float() * (~att_mask.bool()).float().to(device)
            for row in range(attribute.shape[0]):
                if row not in att_idx:
                    attribute_est[row, int(max_idx[row])] = 1

            # Change the GT attributes of half the objects in the first third of
            # the images, sampling replacement attributes from the object/attribute
            # co-occurrence matrix (excluding the attributes already present).
            num_img_to_change = math.floor(imgs.shape[0] / 3)
            for img_idx in range(num_img_to_change):
                obj_indices = torch.nonzero(obj_to_img == img_idx).view(-1)
                num_objs_to_change = math.floor(len(obj_indices) / 2)
                for changed, obj_idx in enumerate(obj_indices):
                    if changed >= num_objs_to_change:
                        break
                    obj = objs[obj_idx]
                    # change GT attribute
                    old_attributes = torch.nonzero(attribute_GT[obj_idx]).view(-1)
                    new_attribute = random.choices(
                        range(106),
                        matrix[obj].scatter(0, old_attributes.cpu(), 0),
                        k=random.randrange(1, 3))
                    attribute[obj_idx] = 0  # remove all attributes for obj
                    attribute[obj_idx] = attribute[obj_idx].scatter(
                        0, torch.LongTensor(new_attribute).to(device), 1)  # assign new attribute
                    # change the estimated attributes the same way
                    attribute_est[obj_idx] = 0
                    attribute_est[obj_idx] = attribute[obj_idx].scatter(
                        0, torch.LongTensor(new_attribute).to(device), 1)

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # Compute image adv loss with fake images.
            out_logits = netD_image(img_rec.detach())
            d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            out_logits = netD_image(img_rand.detach())
            d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # shift image adv loss
            out_logits = netD_image(img_shift.detach())
            d_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            d_image_adv_loss_fake = (0.4 * d_image_adv_loss_fake_rec
                                     + 0.4 * d_image_adv_loss_fake_rand
                                     + 0.2 * d_image_adv_loss_fake_shift)

            # Compute image adv loss with real images.
            out_logits = netD_image(imgs)
            d_image_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # Compute object sn adv loss with fake rec crops.
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # Compute object sn adv loss with fake rand crops.
            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # shift obj adv loss
            out_logits, _ = netD_object(crops_shift.detach(), objs)
            d_object_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            d_object_adv_loss_fake = (0.4 * d_object_adv_loss_fake_rec
                                      + 0.4 * d_object_adv_loss_fake_rand
                                      + 0.2 * d_object_adv_loss_fake_shift)

            # Compute object sn adv and classification losses with real crops.
            out_logits_src, out_logits_cls = netD_object(crops_input.detach(), objs)
            d_object_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Attribute classification loss on annotated real crops only.
            att_cls = netD_att(crops_input.detach())
            att_idx = attribute_GT.sum(dim=1).nonzero().squeeze()
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            attribute_annotated = torch.index_select(attribute_GT, 0, att_idx)
            d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Backward and optimize.
            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake + d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake + d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real
            d_loss += config.lambda_att_cls * d_object_att_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()
            netD_att.zero_grad()
            d_loss.backward()
            netD_image_optimizer.step()
            netD_object_optimizer.step()
            netD_att_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()
            loss['D/object_att_cls_loss'] = d_object_att_cls_loss_real.item()

            # =========================================================================== #
            # 3. Train the generator                                                      #
            # =========================================================================== #
            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # Image reconstruction loss, masked so that images whose GT
            # attributes were changed above are excluded.
            rec_img_mask = torch.ones(imgs.shape[0]).scatter(
                0, torch.LongTensor(range(num_img_to_change)), 0).to(device)
            g_img_rec_loss = rec_img_mask * torch.abs(img_rec - imgs).view(imgs.shape[0], -1).mean(1)
            g_img_rec_loss = g_img_rec_loss.sum() / (imgs.shape[0] - num_img_to_change)

            # Latent code reconstruction loss.
            g_z_rec_loss_rand = torch.abs(z_rand_rec - z).mean()
            g_z_rec_loss_shift = torch.abs(z_rand_shift - z).mean()
            g_z_rec_loss = 0.5 * g_z_rec_loss_rand + 0.5 * g_z_rec_loss_shift

            # KL loss: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
            kl_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Compute image adv loss with fake images.
            out_logits = netD_image(img_rec)
            g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            out_logits = netD_image(img_rand)
            g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # shift image adv loss
            out_logits = netD_image(img_shift)
            g_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            g_image_adv_loss_fake = (0.4 * g_image_adv_loss_fake_rec
                                     + 0.4 * g_image_adv_loss_fake_rand
                                     + 0.2 * g_image_adv_loss_fake_shift)

            # Compute object adv/cls losses with rec crops.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)

            # attribute
            att_cls = netD_att(crops_input_rec)
            att_idx = attribute.sum(dim=1).nonzero().squeeze()
            attribute_annotated = torch.index_select(attribute, 0, att_idx)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rec = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated, pos_weight=pos_weight.to(device))

            # Object losses with rand crops.
            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)

            # attribute
            att_cls = netD_att(crops_rand)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rand = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated, pos_weight=pos_weight.to(device))

            # Object losses with shifted crops.
            out_logits_src, out_logits_cls = netD_object(crops_shift, objs)
            g_object_adv_loss_shift = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_shift = F.cross_entropy(out_logits_cls, objs)

            # attribute
            att_cls = netD_att(crops_shift)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_shift = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated, pos_weight=pos_weight.to(device))

            g_object_att_cls_loss = (0.4 * g_object_att_cls_loss_rec
                                     + 0.4 * g_object_att_cls_loss_rand
                                     + 0.2 * g_object_att_cls_loss_shift)
            g_object_adv_loss = (0.4 * g_object_adv_loss_rec
                                 + 0.4 * g_object_adv_loss_rand
                                 + 0.2 * g_object_adv_loss_shift)
            g_object_cls_loss = (0.4 * g_object_cls_loss_rec
                                 + 0.4 * g_object_cls_loss_rand
                                 + 0.2 * g_object_cls_loss_shift)

            # Backward and optimize.
            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_att_cls * g_object_att_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()
            loss['G/object_att_cls_loss'] = g_object_att_cls_loss.item()

            # =========================================================================== #
            # 4. Log                                                                      #
            # =========================================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, roi_value in loss.items():
                    log += ", {}: {:.4f}".format(tag, roi_value)
                print(log)

            if (i + 1) % config.tensorboard_step == 0 and config.use_tensorboard:
                for tag, roi_value in loss.items():
                    writer.add_scalar(tag, roi_value, i + 1)
                writer.add_images('Result/crop_real',
                                  imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
                writer.add_images('Result/crop_real_rec',
                                  imagenet_deprocess_batch(crops_input_rec).float() / 255, i + 1)
                writer.add_images('Result/crop_rand',
                                  imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
                writer.add_images('Result/img_real',
                                  imagenet_deprocess_batch(imgs).float() / 255, i + 1)
                writer.add_images('Result/img_real_rec',
                                  imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
                writer.add_images('Result/img_fake_rand',
                                  imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:
                save_model(netG, model_dir=model_save_dir, appendix='netG',
                           iter=i + 1, save_num=2, save_step=config.save_step)
                save_model(netD_image, model_dir=model_save_dir, appendix='netD_image',
                           iter=i + 1, save_num=2, save_step=config.save_step)
                save_model(netD_object, model_dir=model_save_dir, appendix='netD_object',
                           iter=i + 1, save_num=2, save_step=config.save_step)
                save_model(netD_att, model_dir=model_save_dir, appendix='netD_attribute',
                           iter=i + 1, save_num=2, save_step=config.save_step)

        if config.use_tensorboard:
            writer.close()
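# --------------------------------------------------------------------------- #
# Sanity check for the fused in-place KL chain used in the generator step: for
# a diagonal Gaussian posterior N(mu, exp(logvar)) against N(0, I),
#     KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
# which is exactly the value kl_element / g_kl_loss compute above.
# --------------------------------------------------------------------------- #
def _kl_sanity_check():
    mu, logvar = torch.randn(8, 64), torch.randn(8, 64)
    # out-of-place version of the chain above, so mu/logvar stay untouched
    kl_element = mu.pow(2).add(logvar.exp()).mul(-1).add(1).add(logvar)
    g_kl_loss = torch.sum(kl_element).mul(-0.5)
    kl_ref = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    assert torch.allclose(g_kl_loss, kl_ref)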
def test_model(model, pickle_files, criterion, input_size=224):
    phase = 'val'
    model.eval()

    running_loss_real = 0.0
    running_corrects_real = 0
    running_count_real = 0
    running_loss_fake = 0.0
    running_corrects_fake = 0
    running_count_fake = 0
    running_loss_fake_shift = 0.0
    running_corrects_fake_shift = 0
    running_count_fake_shift = 0

    # Iterate over the pre-generated batches.
    for i, pickle_file in enumerate(pickle_files):
        with open(pickle_file, 'rb') as f:
            batch = pickle.load(f)
        imgs, imgs_rand, imgs_shift = batch['imgs'], batch['imgs_rand'], batch['imgs_shift']
        objs, boxes, boxes_shift = batch['objs'], batch['boxes'], batch['boxes_shift']
        obj_to_img = batch['obj_to_img']

        inputs_real = crop_bbox_batch(imgs, boxes, obj_to_img, input_size).to(device)
        inputs_fake = crop_bbox_batch(imgs_rand, boxes, obj_to_img, input_size).to(device)
        inputs_fake_shift = crop_bbox_batch(imgs_shift, boxes_shift, obj_to_img, input_size).to(device)
        labels = objs.to(device)

        # forward; no gradient tracking during evaluation
        with torch.no_grad():
            outputs_real = model(inputs_real)
            loss_real = criterion(outputs_real, labels)
            _, preds_real = torch.max(outputs_real, 1)

            outputs_fake = model(inputs_fake)
            loss_fake = criterion(outputs_fake, labels)
            _, preds_fake = torch.max(outputs_fake, 1)

            outputs_fake_shift = model(inputs_fake_shift)
            loss_fake_shift = criterion(outputs_fake_shift, labels)
            _, preds_fake_shift = torch.max(outputs_fake_shift, 1)

        # statistics
        running_loss_real += loss_real.item() * labels.size(0)
        running_corrects_real += torch.sum(preds_real == labels.data)
        running_count_real += labels.size(0)

        running_loss_fake += loss_fake.item() * labels.size(0)
        running_corrects_fake += torch.sum(preds_fake == labels.data)
        running_count_fake += labels.size(0)

        running_loss_fake_shift += loss_fake_shift.item() * labels.size(0)
        running_corrects_fake_shift += torch.sum(preds_fake_shift == labels.data)
        running_count_fake_shift += labels.size(0)

        if (i + 1) % 20 == 0:
            print('real loss: {:.4f} accu: {:.4f} fake loss: {:.4f} accu: {:.4f} '
                  'shift loss: {:.4f} accu: {:.4f}'.format(
                      loss_real.item(),
                      torch.mean((preds_real == labels.data).float()),
                      loss_fake.item(),
                      torch.mean((preds_fake == labels.data).float()),
                      loss_fake_shift.item(),
                      torch.mean((preds_fake_shift == labels.data).float())))

    epoch_loss_real = running_loss_real / running_count_real
    epoch_acc_real = running_corrects_real.double() / running_count_real
    epoch_loss_fake = running_loss_fake / running_count_fake
    epoch_acc_fake = running_corrects_fake.double() / running_count_fake
    epoch_loss_fake_shift = running_loss_fake_shift / running_count_fake_shift
    epoch_acc_fake_shift = running_corrects_fake_shift.double() / running_count_fake_shift

    print('================================================================')
    print('{} Real Loss: {:.4f} Acc: {:.4f} Fake Loss: {:.4f} Acc: {:.4f} '
          'shift Loss: {:.4f} Acc: {:.4f}'.format(
              phase, epoch_loss_real, epoch_acc_real, epoch_loss_fake,
              epoch_acc_fake, epoch_loss_fake_shift, epoch_acc_fake_shift))
def main(config):
    cudnn.benchmark = True
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    with open("data/vocab.json", 'r') as f:
        vocab = json.load(f)
    att_idx_to_name = np.array(vocab['attribute_idx_to_name'])
    print(att_idx_to_name)
    object_idx_to_name = np.array(vocab['object_idx_to_name'])

    attribute_nums = 106
    train_data_loader, val_data_loader = get_dataloader_vg(
        batch_size=config.batch_size, attribute_embedding=attribute_nums)
    vocab_num = train_data_loader.dataset.num_objects

    netG = Generator(num_embeddings=vocab_num,
                     obj_att_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers,
                     obj_size=config.object_size,
                     attribute_dim=attribute_nums).to(device)
    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)
    netD_att = add_sn(netD_att)

    _ = load_model(netD_att, model_dir="~/models/trained_models",
                   appendix='netD_attribute', iter=config.resume_iter)
    _ = load_model(netG, model_dir=model_save_dir, appendix='netG',
                   iter=config.resume_iter)

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    cur_obj_start_idx = 0
    # preallocated for the annotated objects of the validation split
    attributes_pred = np.zeros((253468, attribute_nums))
    attributes_gt = None

    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)
            att_idx = attribute.sum(dim=1).nonzero().squeeze()

            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
                obj_to_img, z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            # Estimate attributes for objects without annotations.
            attribute_est = attribute.clone()
            att_mask = torch.zeros(attribute.shape[0])
            att_mask = att_mask.scatter(0, att_idx, 1).to(device)
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, config.object_size)
            estimated_att = netD_att(crops_input)
            max_idx = estimated_att.argmax(1)
            max_idx = max_idx.float() * (~att_mask.bool()).float().to(device)
            for row in range(attribute.shape[0]):
                if row not in att_idx:
                    attribute_est[row, int(max_idx[row])] = 1

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # Predict attributes on the generated crops of annotated objects.
            crops_rand_yes = torch.index_select(crops_rand, 0, att_idx.to(device))
            attribute_yes = torch.index_select(attribute, 0, att_idx.to(device))
            estimated_att_rand = netD_att(crops_rand_yes)
            att_cls = torch.sigmoid(estimated_att_rand)
            for ll in range(len(att_idx)):
                non0_idx_cls = (att_cls[ll] > 0.9).nonzero()
                pred_idx = non0_idx_cls.squeeze().cpu().numpy()
                attributes_pred[cur_obj_start_idx, pred_idx] = 1
                cur_obj_start_idx += 1

            # Accumulate the GT array.
            attributes_gt = attribute_yes.cpu().numpy() if attributes_gt is None \
                else np.vstack([attributes_gt, attribute_yes.cpu().numpy()])

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)
            img_rec = imagenet_deprocess_batch(img_rec)

            # Attribute modification: force every object's color attribute to tgt.
            changed_list = []
            src = 2   # 94 blue, 95 black
            tgt = 95  # 8 red, 2 white
            for idx in range(objs.size(0)):
                # remove the other color attributes
                attribute[idx, [2, 8, 0, 94, 90, 95, 96, 34, 25, 70, 58, 104]] = 0
                attribute[idx, tgt] = 1
                attribute_est[idx, [2, 8, 0, 94, 90, 95, 96, 34, 25, 70, 58, 104]] = 0
                attribute_est[idx, tgt] = 1
                changed_list.append(idx)

            # Generate images with the modified attribute.
            z = torch.randn(objs.size(0), config.z_dim).to(device)
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand_y, crops_shift_y, img_rec_y,
             img_rand_y, img_shift_y, mu, logvar, z_rand_rec, z_rand_shift) = output

            img_rand_y = imagenet_deprocess_batch(img_rand_y)
            img_shift_y = imagenet_deprocess_batch(img_shift_y)
            img_rec_y = imagenet_deprocess_batch(img_rec_y)
            imgs = imagenet_deprocess_batch(imgs)

            # Keep only objects whose original crops were not already classified
            # as tgt (top-5); count a success when the modified crop is
            # classified as tgt (top-3).
            estimated_att = netD_att(crops_rand)
            max_idx = estimated_att.topk(5)[1]
            changed_list = [c for c in changed_list if tgt not in max_idx[c]]
            estimated_att_y = netD_att(crops_rand_y)
            max_idx_y = estimated_att_y.topk(3)[1]
            changed_list_success = [c for c in changed_list if tgt in max_idx_y[c]]

            for j in range(imgs.shape[0]):
                img_np = img_shift[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}_shift.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                img_np = img_rand[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}_rand.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                img_rec_np = img_rec[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}_rec.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_rec_np)

                img_real_np = imgs[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}_real.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_real_np)

                cur_obj_success = [int(objs[c]) for c in changed_list_success
                                   if obj_to_img[c] == j]

                # save successfully modified images
                if len(cur_obj_success) > 0:
                    # copy into a new array; img_shift_y must stay a batch
                    # tensor for the remaining iterations of j
                    img_shift_np = img_shift_y[j].numpy().transpose(1, 2, 0)
                    img_path = os.path.join(
                        result_save_dir,
                        'img{:06d}_shift_{}_modified.png'.format(
                            i * config.batch_size + j,
                            object_idx_to_name[cur_obj_success]))
                    imwrite(img_path, img_shift_np)

                    img_rec_np = img_rec_y[j].numpy().transpose(1, 2, 0)
                    img_path = os.path.join(
                        result_save_dir,
                        'img{:06d}_rec_modified.png'.format(i * config.batch_size + j))
                    imwrite(img_path, img_rec_np)

                    img_rand_np = img_rand_y[j].numpy().transpose(1, 2, 0)
                    img_path = os.path.join(
                        result_save_dir,
                        'img{:06d}_rand_modified.png'.format(i * config.batch_size + j))
                    imwrite(img_path, img_rand_np)

    # Calculate per-object recall and precision from binary confusion matrices.
    num_data = attributes_gt.shape[0]
    count = np.zeros((num_data, 4))
    recall, precision = np.zeros(num_data), np.zeros(num_data)
    for i in range(num_data):
        # ravel() order: tn, fp, fn, tp; fix labels so the matrix is always 2x2
        count[i] = confusion_matrix(attributes_gt[i], attributes_pred[i],
                                    labels=[0, 1]).ravel()
        recall[i] = count[i][3] / (count[i][3] + count[i][2])
        precision[i] = safe_division(count[i][3], count[i][3] + count[i][1])

    print("average precision = {}".format(precision.mean()))
    print("average recall = {}".format(recall.mean()))
    print("average pred # per obj")
    print((count[:, 1].sum() + count[:, 3].sum()) / count.shape[0])
    print("average GT # per obj")
    print((count[:, 2].sum() + count[:, 3].sum()) / count.shape[0])
    print("% of data that predict something")
    print(((count[:, 1] + count[:, 3]) > 0).sum() / count.shape[0])
    print("% of data at least predicted correct once")
    print((count[:, 3] > 0).sum() / count.shape[0])
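# safe_division is called above (and in the L1-distance script below) but is
# defined elsewhere in the repo; a minimal sketch consistent with its call
# sites, returning 0 when the denominator is 0:
def safe_division(numerator, denominator):
    return numerator / denominator if denominator else 0.0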
def main(config):
    cudnn.benchmark = True
    device = torch.device('cuda:0')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)
    # os.mkdir does not expand "~", so expand it explicitly
    resinet_dir = os.path.expanduser("~/pickle/128_vg_pkl_resinet50_247k")
    if not os.path.exists(resinet_dir):
        os.mkdir(resinet_dir)

    attribute_nums = 106
    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, attribute_embedding=attribute_nums)
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size)
    vocab_num = train_data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)
    _ = load_model(netG, model_dir=model_save_dir, appendix='netG',
                   iter=config.resume_iter)

    netD_att = train_att_change.AttributeDiscriminator(
        n_attribute=attribute_nums).to(device)
    netD_att = train_att_change.add_sn(netD_att)
    _ = load_model(netD_att, model_dir="~/models/trained_models",
                   appendix='netD_attribute', iter=config.resume_iter)

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            att_idx = attribute.sum(dim=1).nonzero().squeeze()
            z = torch.randn(objs.size(0), config.z_dim)
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
                obj_to_img, z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            # Estimate attributes for objects without annotations.
            attribute_est = attribute.clone()
            att_mask = torch.zeros(attribute.shape[0])
            att_mask = att_mask.scatter(0, att_idx, 1).to(device)
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, config.object_size)
            estimated_att = netD_att(crops_input)
            max_idx = estimated_att.argmax(1)
            max_idx = max_idx.float() * (~att_mask.bool()).float().to(device)
            for row in range(attribute.shape[0]):
                if row not in att_idx:
                    attribute_est[row, int(max_idx[row])] = 1

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            dict_to_save = {
                'imgs': imgs,
                'imgs_rand': img_rand,
                'imgs_shift': img_shift,
                'objs': objs,
                'boxes': boxes,
                'boxes_shift': boxes_shift,
                'obj_to_img': obj_to_img,
            }
            out_name = os.path.join(resinet_dir, 'batch_{}.pkl'.format(i))
            with open(out_name, 'wb') as f:
                pickle.dump(dict_to_save, f)

            img_rand = imagenet_deprocess_batch(img_rand)
            img_shift = imagenet_deprocess_batch(img_shift)

            # Save the generated images.
            for j in range(img_rand.shape[0]):
                img_np = img_rand[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                img_np = img_shift[j].numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir,
                    'img{:06d}s.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)
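# The pickles written above match the batch dict that the pickle-based
# test_model expects; a minimal driver sketch (it assumes test_model and a
# trained classifier/criterion are in scope, which this file does not confirm):
def run_pickle_eval(model, criterion, resinet_dir, input_size=224):
    import glob
    # collect the per-batch pickles written by main() above, in batch order
    pickle_files = sorted(glob.glob(os.path.join(resinet_dir, 'batch_*.pkl')))
    test_model(model, pickle_files, criterion, input_size=input_size)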
writer = SummaryWriter(log_save_dir)
for i in range(start_iter, niter):
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        batch = next(data_iter)

    imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
    imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = \
        imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
        obj_to_img, attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

    crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, 64)
    att_cls = netD_att(crops_input.detach())
    att_idx = attribute.sum(dim=1).nonzero().squeeze()
    att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
    attribute_annotated = torch.index_select(attribute, 0, att_idx)
    d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
        att_cls_annotated, attribute_annotated, pos_weight=pos_weight.to(device))

    d_loss = 0
    d_loss += d_object_att_cls_loss_real

    netD_att.zero_grad()
    # the excerpt is truncated here; backward/step follow the same pattern as
    # the other training loops (the optimizer name is assumed, not shown here)
    d_loss.backward()
    netD_att_optimizer.step()
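# --------------------------------------------------------------------------- #
# pos_weight is used with F.binary_cross_entropy_with_logits throughout but is
# defined outside the code shown here. A common recipe (an assumption, not the
# repo's confirmed definition) weights each attribute by its negative/positive
# count ratio over the annotated training objects:
# --------------------------------------------------------------------------- #
def compute_pos_weight(attribute_matrix):
    # attribute_matrix: (num_objects, num_attributes) multi-hot GT labels
    pos = attribute_matrix.sum(dim=0)
    neg = attribute_matrix.shape[0] - pos
    return neg / pos.clamp(min=1)  # rare attributes get a larger positive weight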
def main(config):
    cudnn.benchmark = True
    device = torch.device('cuda:0')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir1, result_save_dir2 = prepare_dir(
        config.exp_name)
    # os.mkdir does not expand "~", so expand it explicitly
    resinet_dir = os.path.expanduser("~/pickle/vg_pkl_resinet50")
    if not os.path.exists(resinet_dir):
        os.mkdir(resinet_dir)

    attribute_nums = 106
    if config.dataset == 'vg':
        train_data_loader, val_data_loader = get_dataloader_vg(
            batch_size=config.batch_size, attribute_embedding=attribute_nums)
    elif config.dataset == 'coco':
        train_data_loader, val_data_loader = get_dataloader_coco(
            batch_size=config.batch_size)
    vocab_num = train_data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)
    _ = load_model(netG, model_dir=model_save_dir, appendix='netG',
                   iter=config.resume_iter)

    netD_att = train_att_change.AttributeDiscriminator(
        n_attribute=attribute_nums).to(device)
    netD_att = train_att_change.add_sn(netD_att)
    _ = load_model(netD_att, model_dir="~/models/trained_models",
                   appendix='netD_attribute', iter=config.resume_iter)

    data_loader = val_data_loader
    data_iter = iter(data_loader)

    L1_dist_b = 0
    L1_dist_f = 0
    L1_rand_dist = 0

    with torch.no_grad():
        netG.eval()
        for i, batch in enumerate(data_iter):
            print('batch {}'.format(i))
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            att_idx = attribute.sum(dim=1).nonzero().squeeze()
            z = torch.randn(objs.size(0), config.z_dim)
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift = \
                imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), \
                obj_to_img, z.to(device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            # Estimate attributes for objects without annotations.
            attribute_est = attribute.clone()
            att_mask = torch.zeros(attribute.shape[0])
            att_mask = att_mask.scatter(0, att_idx, 1).to(device)
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img, config.object_size)
            estimated_att = netD_att(crops_input)
            max_idx = estimated_att.argmax(1)
            max_idx = max_idx.float() * (~att_mask.bool()).float().to(device)
            for row in range(attribute.shape[0]):
                if row not in att_idx:
                    attribute_est[row, int(max_idx[row])] = 1

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            (crops_input, crops_input_rec, crops_rand, crops_shift, img_rec,
             img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift) = output

            # An object counts as foreground when its box is narrower than half
            # the image width.
            foreground = list(map(lambda x: (x.data[2] - x.data[0]) < 0.5, boxes))

            dict_to_save = {
                'imgs': imgs,
                'imgs_rand': img_rand,
                'imgs_shift': img_shift,
                'objs': objs,
                'boxes': boxes,
                'boxes_shift': boxes_shift,
                'obj_to_img': obj_to_img,
                'foreground': foreground,
            }
            out_name = os.path.join(resinet_dir, 'batch_{}.pkl'.format(i))
            with open(out_name, 'wb') as f:
                pickle.dump(dict_to_save, f)

            img_rand = imagenet_deprocess_batch(img_rand).type(torch.int32)
            imgs = imagenet_deprocess_batch(imgs).type(torch.int32)
            img_shift = imagenet_deprocess_batch(img_shift).type(torch.int32)

            # Save the generated images and accumulate L1 distances.
            fore_count = 0
            background_dist = 0
            foreground_dist = 0
            rand_diff = 0
            for j in range(img_rand.shape[0]):
                obj_indices = torch.nonzero(obj_to_img == j).view(-1)
                foreground_obj_indices = [int(k) for k in obj_indices if foreground[k] == 1]

                # Union of the original and shifted foreground masks.
                if foreground_obj_indices:
                    foreground_mask = torch.max(
                        torch.max(masks[foreground_obj_indices], 0)[0],
                        torch.max(masks_shift[foreground_obj_indices], 0)[0]).bool()
                else:
                    print("no foreground")
                    foreground_mask = torch.zeros(1, 64, 64).bool().to(device)

                # Mean absolute background difference between rand and shift.
                background_dist += safe_division(
                    torch.abs((img_rand[j] - img_shift[j]) * (~foreground_mask).cpu().int()).sum(),
                    3 * (~foreground_mask).cpu().float().sum())

                img_np = (img_rand[j] * (~foreground_mask).int().cpu()).numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir1,
                    'img{:06d}.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                img_np = (img_shift[j] * (~foreground_mask).int().cpu()).numpy().transpose(1, 2, 0)
                img_path = os.path.join(
                    result_save_dir2,
                    'img{:06d}.png'.format(i * config.batch_size + j))
                imwrite(img_path, img_np)

                # Mean absolute foreground difference, object by object.
                for foreground_i in foreground_obj_indices:
                    try:
                        foreground_dist += torch.abs(
                            torch.masked_select(img_rand[j], masks[foreground_i].cpu().bool())
                            - torch.masked_select(img_shift[j], masks_shift[foreground_i].cpu().bool())
                        ).sum() / (3 * masks[foreground_i].sum())
                        fore_count += 1
                    except RuntimeError:
                        # the two masks can select different numbers of pixels
                        continue

                # Baseline: distance between unrelated images (rand of image j
                # vs. shift of image j - 1); only the batch's last pair is kept.
                rand_diff = torch.abs(img_rand[j] - img_shift[j - 1]).sum() / (3 * 64 * 64)

            L1_dist_b += background_dist / config.batch_size
            L1_dist_f += safe_division(foreground_dist, fore_count)
            L1_rand_dist += rand_diff

            try:
                print("running b_dist = %d, f_dist = %d, rand_diff = %d"
                      % (background_dist / config.batch_size,
                         safe_division(foreground_dist, fore_count), rand_diff))
            except ValueError:
                # %d cannot format NaN; report which terms are NaN instead
                print(math.isnan(background_dist / config.batch_size),
                      math.isnan(safe_division(foreground_dist, fore_count)),
                      math.isnan(rand_diff))

    num_batches = i + 1  # i holds the last batch index
    print("Background L1 dist = %f" % (L1_dist_b / num_batches))
    print("Foreground L1 dist = %f" % (L1_dist_f / num_batches))
    print("Rand L1 dist = %f" % (L1_rand_dist / num_batches))
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,
                is_inception=False, input_size=224, start_epoch=0):
    since = time.time()
    val_acc_history = []
    writer = {"train": SummaryWriter(log_save_dir + exp_name + "_train/"),
              "val": SummaryWriter(log_save_dir + exp_name + "_val/")}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(start_epoch, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            running_count = 0

            # Iterate over data.
            for i, batch in enumerate(dataloaders[phase]):
                imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
                inputs = crop_bbox_batch(imgs, boxes, obj_to_img, input_size)
                inputs = inputs.to(device)
                labels = objs.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss. Special case for
                    # inception: in training it has an auxiliary output, and the
                    # loss sums the final and auxiliary losses; in testing only
                    # the final output is considered.
                    if is_inception and phase == 'train':
                        # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * labels.size(0)
                running_corrects += torch.sum(preds == labels.data)
                running_count += labels.size(0)

                if (i + 1) % 20 == 0:
                    print('epoch: {:04d} iter: {:08d} loss: {:.4f} accu: {:.4f}'.format(
                        epoch + 1, i + 1, loss.item(),
                        torch.mean((preds == labels.data).float())))

                # Logging.
                loss_logging = {}
                loss_logging['avg loss'] = loss.item()
                loss_logging['accu'] = torch.mean((preds == labels.data).float())
                if (i + 1) % 100 == 0:
                    for tag, roi_value in loss_logging.items():
                        writer[phase].add_scalar(tag, roi_value,
                                                 epoch * iters_per_epoch[phase] + i + 1)

            epoch_loss = running_loss / running_count
            epoch_acc = running_corrects.double() / running_count
            print('================================================================')
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model; the accuracy check is intentionally commented
            # out, so the latest val epoch's weights are always kept
            if phase == 'val':  # and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                save_model(model, model_save_dir, '128_resinet50_vg_best',
                           epoch + 1, save_step=1)
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    writer['train'].close()
    writer['val'].close()

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
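# --------------------------------------------------------------------------- #
# Illustrative driver for train_model, fine-tuning a torchvision ResNet-50 on
# object crops; the optimizer, learning rate, and class count below are
# assumptions for the sketch, not values taken from this file.
# --------------------------------------------------------------------------- #
def build_and_train(dataloaders, num_classes, num_epochs=25):
    import torch.nn as nn
    from torchvision import models

    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # replace the head
    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    return train_model(model, dataloaders, criterion, optimizer,
                       num_epochs=num_epochs, is_inception=False, input_size=224)
# num_classes would typically be the object vocabulary size reported by the
# dataloader (dataloaders['train'].dataset.num_objects).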