# Modules used by the two functions below (add these if they are not already
# imported at module level).
import time
from collections import Counter

import cv2
import numpy as np
import torch
import torch.nn as nn

# Both functions rely on module-level globals defined elsewhere in this script:
# the data arrays (train_* / val_*), lookup_enc, cuda_cc, vrev, device, args,
# loss_function, error_metric, cvrgb2lab and labim2rgb.


def evaluate(minibatches, net, epoch, img_save_folder, save_every=20):
    stime = time.time()
    c = Counter()
    val_full_loss = 0.
    val_masked_loss = 0.
    val_loss = 0.
    n_val_ims = 0
    for i, (batch_start, batch_end) in enumerate(minibatches):
        img_rgbs = val_origs[batch_start:batch_end]
        img_labs = np.array([cvrgb2lab(img_rgb) for img_rgb in img_rgbs])
        input_ = torch.from_numpy(val_ims[batch_start:batch_end])
        # Ground-truth ab channels, subsampled 4x to the 56x56 output resolution,
        # encoded into quantised colour-bin targets.
        gt_abs = img_labs[:, ::4, ::4, 1:]
        target = torch.from_numpy(lookup_enc.encode_points(gt_abs))
        input_captions_ = val_words[batch_start:batch_end]
        input_lengths_ = val_lengths[batch_start:batch_end]
        input_captions = torch.from_numpy(input_captions_.astype('int32')).long().to(device)
        input_caption_lens = torch.from_numpy(input_lengths_.astype('int32')).long().to(device)
        input_ims = input_.float().to(device)
        target = target.long().to(device)
        with torch.no_grad():
            output, _ = net(input_ims, input_captions, input_caption_lens)
        # Softmax over the 625 colour bins, then the expected ab value via the bin grid.
        dec_inp = nn.Softmax(dim=1)(output)   # (N*56*56) x 625
        AB_vals = dec_inp.mm(cuda_cc)         # (N*56*56) x 2
        # Reshape back to per-image ab maps.
        AB_vals = AB_vals.view(len(img_labs), 56, 56, 2).cpu().numpy()
        n_val_ims += len(AB_vals)
        for k, (img_rgb, AB_val) in enumerate(zip(img_rgbs, AB_vals)):
            AB_val = cv2.resize(AB_val, (224, 224), interpolation=cv2.INTER_CUBIC)
            img_dec = labim2rgb(np.dstack((np.expand_dims(img_labs[k, :, :, 0], axis=2),
                                           AB_val)))
            val_loss += error_metric(img_dec, img_rgb)
            if k == 0 and i % save_every == 0:
                word_list = list(input_captions_[k, :input_lengths_[k]])
                words = '_'.join(vrev.get(w, 'unk') for w in word_list)
                img_labs_tosave = labim2rgb(img_labs[k])
                cv2.imwrite('%s/%d_%d_bw.jpg' % (img_save_folder, epoch, i),
                            cv2.cvtColor(img_rgbs[k].astype('uint8'), cv2.COLOR_RGB2GRAY))
                cv2.imwrite('%s/%d_%d_color.jpg' % (img_save_folder, epoch, i),
                            img_rgbs[k].astype('uint8'))
                cv2.imwrite('%s/%d_%d_rec_%s.jpg' % (img_save_folder, epoch, i, words),
                            img_dec.astype('uint8'))
    return val_loss / len(minibatches)  # , val_masked_loss / len(minibatches)
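# error_metric is defined elsewhere in this script and is only called above.
# A minimal sketch of one plausible choice, assuming a per-pixel RMSE between
# the decoded image and the original RGB image (both HxWx3, values 0-255);
# the real error_metric may compute something different.
def rgb_rmse_sketch(img_dec, img_rgb):
    diff = img_dec.astype('float32') - img_rgb.astype('float32')
    return float(np.sqrt(np.mean(diff ** 2)))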
def train(minibatches, net, optimizer, epoch, prior_probs, img_save_folder):
    # Note: prior_probs is accepted here but is not currently used inside this function.
    stime = time.time()
    c = Counter()
    for i, (batch_start, batch_end) in enumerate(minibatches):
        img_rgbs = train_origs[batch_start:batch_end]
        img_labs = np.array([cvrgb2lab(img_rgb) for img_rgb in img_rgbs])
        input_ = torch.from_numpy(train_ims[batch_start:batch_end])
        # Ground-truth ab channels, subsampled 4x, encoded into quantised colour-bin targets.
        target = torch.from_numpy(lookup_enc.encode_points(img_labs[:, ::4, ::4, 1:]))
        # rand_idx = np.random.randint(5)  # 5 captions per image
        input_captions_ = train_words[batch_start:batch_end]
        input_lengths_ = train_lengths[batch_start:batch_end]
        # For now just use the first caption of each image.
        input_captions = torch.from_numpy(input_captions_.astype('int32')).long().to(device)
        input_caption_lens = torch.from_numpy(input_lengths_.astype('int32')).long().to(device)
        input_ims = input_.float().to(device)
        target = target.long().to(device)
        optimizer.zero_grad()
        output, _ = net(input_ims, input_captions, input_caption_lens)
        loss = loss_function(output, target.view(-1))
        loss.backward()
        optimizer.step()
        if i % 50 == 0:
            print('loss at epoch %d, batch %d / %d = %f, time: %f s' %
                  (epoch, i, len(minibatches), loss.item(), time.time() - stime))
            stime = time.time()
            if True:  # args.logs:
                # Softmax over the 625 colour bins, then the expected ab value via the bin grid.
                dec_inp = nn.Softmax(dim=1)(output)   # (N*56*56) x 625
                AB_vals = dec_inp.mm(cuda_cc)         # (N*56*56) x 2
                # Reshape and select the last image of the batch.
                AB_vals = AB_vals.view(len(img_labs), 56, 56, 2)[-1].detach().cpu().numpy()
                AB_vals = cv2.resize(AB_vals, (224, 224), interpolation=cv2.INTER_CUBIC)
                img_dec = labim2rgb(np.dstack((np.expand_dims(img_labs[-1, :, :, 0], axis=2),
                                               AB_vals)))
                img_labs_tosave = labim2rgb(img_labs[-1])
                word_list = list(input_captions_[-1, :input_lengths_[-1]])
                words = '_'.join(vrev.get(w, 'unk') for w in word_list)
                cv2.imwrite('%s/%d_%d_bw.jpg' % (img_save_folder, epoch, i),
                            cv2.cvtColor(img_rgbs[-1].astype('uint8'), cv2.COLOR_RGB2GRAY))
                cv2.imwrite('%s/%d_%d_color.jpg' % (img_save_folder, epoch, i),
                            img_rgbs[-1].astype('uint8'))
                cv2.imwrite('%s/%d_%d_rec_%s.jpg' % (img_save_folder, epoch, i, words),
                            img_dec.astype('uint8'))
        if i == 0:
            # Save a checkpoint once per epoch, at the first batch.
            torch.save({
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss.item(),
            }, args.model_save_file + '_' + str(epoch) + '_' + str(i) + '.pth.tar')
    return net
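# A minimal sketch of how the two functions above could be driven per epoch.
# The epoch count, the (batch_start, batch_end) index lists and the save folder
# are assumed to be supplied by the caller; prior_probs is assumed to be the
# module-level array passed into train() elsewhere in this script.
def run_training_sketch(net, optimizer, train_minibatches, val_minibatches,
                        n_epochs, img_save_folder):
    best_val = float('inf')
    for epoch in range(n_epochs):
        net.train()
        net = train(train_minibatches, net, optimizer, epoch,
                    prior_probs, img_save_folder)
        net.eval()
        val_err = evaluate(val_minibatches, net, epoch, img_save_folder)
        best_val = min(best_val, val_err)
        print('epoch %d: validation error %f (best so far %f)' % (epoch, val_err, best_val))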