def __init__(self, netG, label_list, device, out, num_samples=10,
             batch_size=100, data_range=(-1, 1)):
    self.netG = netG
    self.label_list = label_list
    self.device = device
    self.out = out
    self.num_samples = num_samples
    self.num_columns = len(label_list)
    self.batch_size = batch_size
    self.data_range = data_range

    z_base = netG.sample_z(num_samples).to(device)
    z = z_base.clone().unsqueeze(1).repeat(1, self.num_columns, 1)
    self.fixed_z = z.view(-1, netG.latent_dim)

    s_base = []
    for i in range(len(label_list)):
        s = 0
        for j in range(len(label_list[i])):
            y = torch.tensor(label_list[i][j]).unsqueeze(0).to(device)
            s += util.one_hot(y, netG.num_classes) / len(label_list[i])
        s_base.append(s)
    self.fixed_s = torch.cat(s_base, dim=0).repeat(self.num_samples, 1)
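# A minimal, hypothetical usage sketch for the fixed visualization batch built above.
# It assumes netG(z, s) accepts the soft one-hot label vectors in fixed_s, and that `vis`
# is an instance of the class this __init__ belongs to; save_fixed_grid, the output
# filename, and the rescaling from data_range to [0, 1] are illustrative assumptions,
# not part of the original API.
import os
import torch
from torchvision.utils import save_image

def save_fixed_grid(vis, filename='fixed_samples.png'):
    vis.netG.eval()
    with torch.no_grad():
        x = vis.netG(vis.fixed_z, vis.fixed_s)   # one row per z sample, one column per label group
    lo, hi = vis.data_range
    x = (x - lo) / (hi - lo)                     # map data_range to [0, 1] for saving
    save_image(x, os.path.join(vis.out, filename), nrow=vis.num_columns)
    vis.netG.train()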
def __getitem__(self, index):
    image = util.load_image(self.file_names[index])
    label = util.load_label(self.file_names[index])
    image, label = util.random_crop(image, label)
    image, label = util.random_augmentation(image, label)
    one_hot = util.one_hot(label, self.palette)
    return np.float32(image) / 255.0, np.float32(one_hot)
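# A plausible sketch of the palette-based one-hot encoding assumed by __getitem__ above,
# under the assumption that `label` is an H x W x 3 colour map and `palette` is a list of
# RGB tuples; the real util.one_hot may use a different layout (e.g. channels first).
import numpy as np

def one_hot_from_palette(label, palette):
    h, w, _ = label.shape
    out = np.zeros((h, w, len(palette)), dtype=np.float32)
    for ch, colour in enumerate(palette):
        # 1.0 wherever the pixel colour matches this palette entry
        out[..., ch] = np.all(label == np.array(colour), axis=-1)
    return out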
def reconstruct_one_frame(self, cat, input_frame=None, mode='sample'):
    """
    reconstruct the input frame
    :param input_frame: Variable of a tensor [batch_size, c, h, w]
    :param cat: Variable of a tensor [batch_size, num_categories]
    :param mode: 'sample', 'mean', or 'random' for the content latent space
    :return: reconstructed_frame, mean, logvar for content latent space
    """
    # prepare one_hot vector
    one_hot_matrix = Variable(one_hot(cat.data, self.num_categories))
    batch_size = int(cat.shape[0])

    if mode == 'random':
        z_cont = self.get_rand_var(batch_size, self.cont_dim)
        mean, logvar = None, None
    else:
        cont_h = self.contEnc(input_frame)
        if mode == 'sample':
            z_cont, mean, logvar = self.contSampler(cont_h, one_hot_matrix)
        elif mode == 'mean':
            z_cont, mean, logvar = self.contSampler(cont_h, one_hot_matrix,
                                                    use_mean=True)
        else:
            raise ValueError('mode %s is not supported' % mode)

    cont_state = self.contMotionStateGen.cont_forward(z_cont, one_hot_matrix)
    frame, scale_feat = self.videoDec(cont_state)
    return frame, mean, logvar, scale_feat
def load_all(styles, time_steps):
    """
    Loads all MIDI files as a piano roll.
    Prepares midi files for training.
    """
    note_data = []
    beat_data = []
    style_data = []
    note_target = []

    # TODO: Can speed this up with better parallel loading. Order guarantee.
    styles = [y for x in styles for y in x]

    for style_id, style in enumerate(styles):
        style_hot = one_hot(style_id, NUM_STYLES)
        # Parallel process all files into a list of music sequences
        seqs = Parallel(n_jobs=multiprocessing.cpu_count(), backend='threading')(
            delayed(load_midi)(f) for f in get_all_files([style]))

        for seq in seqs:
            if len(seq) >= time_steps:
                # Clamp MIDI to note range
                seq = clamp_midi(seq)
                # Create training data and labels
                train_data, label_data = stagger(seq, time_steps)
                note_data += train_data
                note_target += label_data

    note_data = np.array(note_data)
    note_target = np.array(note_target)
    return note_data, note_target
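# Hypothetical sketch of the stagger() helper used above, assuming it turns a sequence
# into next-step prediction pairs: inputs are windows of length time_steps and targets
# are the same windows shifted by one step. The original implementation may chunk or
# pad differently; this is only meant to illustrate the input/target relationship.
def stagger_sketch(seq, time_steps):
    train_data, label_data = [], []
    for start in range(len(seq) - time_steps):
        train_data.append(seq[start:start + time_steps])
        label_data.append(seq[start + 1:start + 1 + time_steps])
    return train_data, label_data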
def full_test(self, cat, video_len):
    """
    test mode to generate the entire video
    :param cat: variable, the action class with the size [batch]
    :param video_len: int, the desired length of the video
    :return: the generated video [batch, video_len, c, h, w] and the list of motion masks
    """
    masks = []
    frames = []
    one_hot_matrix = Variable(one_hot(cat.data, self.num_categories))

    # generate the first image
    first_frame, _, _, _ = self.reconstruct_one_frame(cat, mode='random')
    frames.append(first_frame)

    # generate the last motion latent vectors for the generation at each time step
    batch_size = int(cat.shape[0])
    z_motion = self.get_rand_var(batch_size * (video_len - 1), self.motion_dim)
    one_hot_matrix_tile = one_hot_matrix.unsqueeze(1).repeat(
        1, video_len - 1, 1).view(-1, self.num_categories)
    motion_h_last = self.trajGenerator.fc_forward(
        z_motion, one_hot_matrix_tile).view(-1, video_len - 1, 16, 8, 8)

    # get initial state
    state = None
    for idx in range(video_len - 1):
        # get content encoding
        cont_h = self.contEnc(frames[-1])
        z_cont, _, _ = self.contSampler(cont_h, one_hot_matrix, use_mean=True)

        # get all previous motion encoding
        if idx > 0:
            current_diff = frames[-2] - frames[-1]
            motion_h_prev = self.motionEnc(current_diff)
            lstm_input = motion_h_prev.unsqueeze(1)
            _, state = self.trajGenerator(lstm_input, state_0=state)

        # update the convLSTM with the current motion variable
        lstm_input = motion_h_last[:, idx].unsqueeze(1)
        motion_state, _ = self.trajGenerator(lstm_input, state_0=state)

        # get the input for the content and motion generators
        cont_state, motion_state = self.contMotionStateGen(
            z_cont, motion_state.squeeze(), one_hot_matrix)

        # get the motion kernels and masks, generate the next frame
        transforms = self.kernelGen(motion_state)
        frame, _ = self.videoDec(cont_state, transforms)
        masks.append(transforms['mask'])
        frames.append(frame)

    video = torch.stack(frames, dim=1)
    return video, masks
def interpolate(G, z, shifts_r, shifts_count, dim, direction_size,
                deformator=None, with_central_border=False):
    shifted_images = []
    tmp = []
    shift_time = 0
    batch_num = z.shape[0]

    for shift in np.arange(-shifts_r, shifts_r + 1e-9, shifts_r / shifts_count):
        shift_time += 1
        if deformator is not None:
            z_deformed = z + deformator(
                one_hot(dims=direction_size, value=shift, indx=dim).cuda())
        else:
            z_deformed = z + one_hot(dims=direction_size, value=shift, indx=dim).cuda()
        shifted_image = G(z_deformed).cpu()
        if shift == 0.0 and with_central_border:
            shifted_image = add_border(shifted_image)
        for index in range(shifted_image.shape[0]):
            tmp.append(shifted_image[index])

    for single_img in range(batch_num):
        for shift in range(shift_time):
            shifted_images.append(tmp[shift * batch_num + single_img])
    return shifted_images
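# For reference, a minimal sketch of the one_hot(dims, value, indx) helper that
# interpolate() relies on, inferred only from its call sites here: a length-`dims`
# vector that is zero everywhere except position `indx`, which holds the (possibly
# fractional) shift value. The original helper's dtype and shape handling may differ.
import torch

def one_hot_shift(dims, value, indx):
    vec = torch.zeros(dims)
    vec[indx] = value
    return vec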
def val(epoch, loader, model, optimizer, scheduler, device, img_size):
    model.eval()
    loader = tqdm(loader)
    sample_size = 8

    criterion = nn.BCELoss(reduction='none').to(device)

    bce_sum = 0
    bce_n = 0

    with torch.no_grad():
        for ep_idx, (episode, _) in enumerate(loader):
            for i in range(episode.shape[1]):
                img = episode[:, i]
                img = util.resize_img(img, type='seg', size=img_size)
                img = img.to(device)
                bce_weight = util.loss_weights(img).to(device)
                img = util.one_hot(img, device=device)

                out, latent_loss, _ = model(img)
                recon_loss = criterion(out, img)
                recon_loss = (recon_loss * bce_weight).mean()
                latent_loss = latent_loss.mean()

                bce_sum += recon_loss.item() * img.shape[0]
                bce_n += img.shape[0]

                lr = optimizer.param_groups[0]['lr']

                loader.set_description((
                    f'validation; bce: {recon_loss.item():.5f}; '
                    f'latent: {latent_loss.item():.3f}; avg bce: {bce_sum/bce_n:.5f}; '
                    f'lr: {lr:.5f}'))

    model.train()
    return bce_sum / bce_n
def extract(loader, model, device, img_size, img_type, destination):
    loader = tqdm(loader)
    for ep_idx, (episode, _) in enumerate(loader):
        sequence = None
        for i in range(episode.shape[1]):
            img = episode[:, i]
            img = util.resize_img(img, type=img_type, size=img_size)
            img = img.to(device)
            if img_type == 'seg':
                img = util.one_hot(img, device=device)
            with torch.no_grad():
                _, _, code = model.encode(img)
            if i == 0:
                sequence = code.reshape(1, img.shape[0], -1)
            else:
                sequence = torch.cat(
                    (sequence, code.reshape(1, img.shape[0], -1)), 0)
        sequence = sequence.cpu().numpy()
        np.save(destination + f'/episode_{ep_idx}', sequence)
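# Usage sketch for the arrays written by extract(): each saved file holds one episode of
# latent codes with shape [episode_len, batch, code_dim]. The directory name and episode
# index below are illustrative placeholders, not paths from the original project.
import numpy as np

codes = np.load('features/episode_0.npy')
print(codes.shape)   # e.g. (episode_len, batch, code_dim)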
def predict_next_frame(self, prev_frame, diff_frames, timestep, cat, mode='sample'):
    """
    predict the next frame given the prev_frame, diff_frames
    :param prev_frame: Variable of a tensor [batch_size, c, h, w]
    :param diff_frames: Variable of a tensor [batch_size, c, video_len-1, h, w]
    :param timestep: int which indicates which time step to predict
    :param cat: Variable of a tensor [batch_size, num_categories]
    :param mode: 'sample', 'mean', or 'random' for the motion latent space
    :return: predicted frame, mean, logvar for the motion latent space, multi-scale features
    """
    # prepare one_hot vector
    one_hot_matrix = Variable(one_hot(cat.data, self.num_categories))
    batch_size, c, h, w = (int(prev_frame.shape[0]), int(prev_frame.shape[1]),
                           int(prev_frame.shape[2]), int(prev_frame.shape[3]))

    #################################### Content encoding #######################################
    # get the content feature
    cont_h = self.contEnc(prev_frame)
    z_cont, _, _ = self.contSampler(cont_h, one_hot_matrix, use_mean=True)

    #################################### Motion encoding #######################################
    # get the motion feature for all diff maps
    motion_input = diff_frames[:, :, :timestep].permute(
        0, 2, 1, 3, 4).contiguous().view(-1, c, h, w)
    motion_h = self.motionEnc(motion_input).view(-1, timestep, 16, 8, 8)

    # get the motion feature for the last diff map
    motion_h_cur = motion_h[:, -1]
    # motion_h_cur: [batch, 512]
    motion_h_cur = self.motionEnc.fc_forward(motion_h_cur)

    # z_motion_cur expected shape: [batch, self.motion_dim]
    if mode == 'sample':
        z_motion_cur, mean, logvar = self.motionSampler(motion_h_cur, one_hot_matrix)
    elif mode == 'random':
        z_motion_cur = self.get_rand_var(batch_size, self.motion_dim)
        mean, logvar = None, None
    elif mode == 'mean':
        z_motion_cur, mean, logvar = self.motionSampler(motion_h_cur, one_hot_matrix,
                                                        use_mean=True)
    else:
        raise ValueError('mode %s is not supported' % mode)

    # motion_h_cur: [batch_size, 1, 16, 8, 8]
    # motion_h is a sequence which is passed into convLSTM with the size
    # [batch_size, timestep, 16, 8, 8]
    motion_h_cur = self.trajGenerator.fc_forward(z_motion_cur, one_hot_matrix).unsqueeze(1)
    if timestep > 1:
        motion_h_prev = motion_h[:, :-1]
        motion_h = torch.cat((motion_h_prev, motion_h_cur), dim=1)
    else:
        motion_h = motion_h_cur

    lstm_input = motion_h
    motion_states, _ = self.trajGenerator(lstm_input)
    # motion_state: [batch_size, seq_len, 16, 8, 8]
    motion_state = motion_states[:, -1]

    ################################### Generate the next frame #######################################
    cont_state, motion_state = self.contMotionStateGen(z_cont, motion_state, one_hot_matrix)
    transforms = self.kernelGen(motion_state)
    frame, scale_feat = self.videoDec(cont_state, transforms)
    return frame, mean, logvar, scale_feat
def reconstruct_seq(self, images, cat, diff_frames, epsilon, mode='sample'):
    """
    The final task for motion stream: reconstruct the entire sequence
    given the first frame and all diff frames
    :param images: the ground-truth video [batch, c, t, h, w]
    :param diff_frames: Variable of a tensor [batch, c, video_len-1, h, w]
    :param cat: Variable of a tensor [batch_size]
    :param epsilon: the probability of the ground truth input in the scheduled sampling
    :param mode: 'sample' for sampling from the motion distribution;
                 'mean' for picking the mean of the motion distribution;
                 'random' for sampling from the standard normal distribution N(0, 1)
    :return: the reconstructed sequence [batch, c, t, h, w], mean, logvar for the
             motion latent space for all time steps
    """
    frames = []
    batch_size = int(cat.shape[0])
    video_len = int(diff_frames.shape[2]) + 1
    one_hot_matrix = Variable(one_hot(cat.data, self.num_categories))
    c, h, w = (int(diff_frames.shape[1]), int(diff_frames.shape[3]),
               int(diff_frames.shape[4]))

    ########################## reconstruct the first frame ##################################
    first_frame = images[:, :, 0]
    first_frame, _, _, _ = self.reconstruct_one_frame(
        cat, input_frame=first_frame, mode='mean')
    frames.append(first_frame)

    ########################### get motion embedding ##########################################
    # get the gt motion encoded vector without sampling from the generated distribution
    motion_input = diff_frames.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h, w)
    # motion_h expected shape [batch_size x (video_len - 1), gf_dim, 8, 8]
    motion_h = self.motionEnc(motion_input)
    # motion_h_fc expected shape [batch_size x (video_len - 1), 512]
    motion_h_fc = self.motionEnc.fc_forward(motion_h)

    # encode every timestep to the motion latent space
    one_hot_matrix_tile = one_hot_matrix.unsqueeze(1).repeat(
        1, video_len - 1, 1).view(-1, self.num_categories)
    if mode == 'random':
        z_motion = self.get_rand_var(batch_size * (video_len - 1), self.motion_dim)
        mean, logvar = None, None
    elif mode == 'sample':
        z_motion, mean, logvar = self.motionSampler(motion_h_fc, one_hot_matrix_tile)
    elif mode == 'mean':
        z_motion, mean, logvar = self.motionSampler(motion_h_fc, one_hot_matrix_tile,
                                                    use_mean=True)
    else:
        raise ValueError('mode %s is not supported' % mode)

    # get 2D feature map from the latent vector
    motion_h_last = self.trajGenerator.fc_forward(
        z_motion, one_hot_matrix_tile).view(-1, video_len - 1, 16, 8, 8)

    ###################################### Generate the video frame by frame ################################
    motion_h = motion_h.view(-1, video_len - 1, 16, 8, 8)
    state = None
    scale_feats = []
    # generate each frame
    for idx in range(video_len - 1):
        # scheduled sampling for the content encoder
        if random.random() > (1 - epsilon):
            cont_h = self.contEnc(images[:, :, idx])
        else:
            cont_h = self.contEnc(frames[-1])
        # encode the content
        z_cont, _, _ = self.contSampler(cont_h, one_hot_matrix, use_mean=True)

        if idx > 0:
            # scheduled sampling for the motion encoder
            if random.random() > (1 - epsilon):
                motion_h_prev = motion_h[:, idx - 1]
            else:
                current_diff = frames[-2] - frames[-1]
                motion_h_prev = self.motionEnc(current_diff)
            # input the previous diff map to update state
            lstm_input = motion_h_prev.unsqueeze(1)
            _, state = self.trajGenerator(lstm_input, state_0=state)

        # encode the motion
        # motion_h_cur expected shape [batch_size, 1, 16, 8, 8]
        motion_h_cur = motion_h_last[:, idx].unsqueeze(1)
        lstm_input = motion_h_cur
        motion_state, _ = self.trajGenerator(lstm_input, state_0=state)

        # decode the next frame
        cont_state, motion_state = self.contMotionStateGen(
            z_cont, motion_state.squeeze(), one_hot_matrix)
        transforms = self.kernelGen(motion_state)
        frame, scale_feat = self.videoDec(cont_state, transforms)
        frames.append(frame)
        # record the feature at each scale
        scale_feats += scale_feat

    video = torch.stack(frames, dim=2)
    return video, mean, logvar, scale_feats
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Common definitions for NER
"""
from utils.util import one_hot

LBLS = [
    "PER",
    "ORG",
    "LOC",
    "MISC",
    "O",
]
NONE = "O"
LMAP = {k: one_hot(5, i) for i, k in enumerate(LBLS)}
NUM = "NNNUMMM"
UNK = "UUUNKKK"

EMBED_SIZE = 50
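# Illustrative only: under the assumption that one_hot(n, i) returns a length-n vector
# with a 1 at index i and 0 elsewhere, LMAP maps each tag to its one-hot label, e.g.
# LMAP["ORG"] corresponds to [0, 1, 0, 0, 0]. A minimal stand-in with that behaviour:
def _one_hot_sketch(n, i):
    v = [0] * n
    v[i] = 1
    return v

assert _one_hot_sketch(5, 1) == [0, 1, 0, 0, 0]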
def update(self):
    netG, netD = self.netG, self.netD
    netG.train()
    netD.train()
    optimizerG, optimizerD = self.optimizerG, self.optimizerD

    if self.iteration > 0:
        # Train G
        g_batch_size = self.g_bs_multiple * self.batch_size

        z = netG.sample_z(g_batch_size).to(self.device)
        y_noisy = torch.cat([self.label_noisy] * self.g_bs_multiple,
                            dim=0).to(self.device)
        y_clean = self.noise2clean(y_noisy)
        x_fake = netG(z, y_clean)
        out_dis, out_cls, _ = netD(x_fake)

        if self.gan_loss == 'wgan':
            g_loss_fake = -out_dis.mean()
        elif self.gan_loss == 'hinge':
            g_loss_fake = -out_dis.mean()
        else:
            g_loss_fake = F.binary_cross_entropy_with_logits(
                out_dis, torch.ones(g_batch_size).to(self.device))
        g_loss_cls = self.lambda_cls_g * F.cross_entropy(out_cls, y_clean)

        g_loss = 0
        g_loss = g_loss + g_loss_fake
        g_loss = g_loss + g_loss_cls

        netG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        self.loss['G/loss_fake'] = g_loss_fake.item()
        self.loss['G/loss_cls'] = g_loss_cls.item()

    # Train D
    for i in range(self.num_critic):
        image_real, label_noisy = next(self.iterator)
        x_real = image_real.to(self.device)
        y_noisy = label_noisy.to(self.device)
        out_dis, out_cls, out_feat = netD(x_real)

        if self.gan_loss == 'wgan':
            d_loss_real = -out_dis.mean()
        elif self.gan_loss == 'hinge':
            d_loss_real = (F.relu(1. - out_dis)).mean()
        else:
            d_loss_real = F.binary_cross_entropy_with_logits(
                out_dis, torch.ones(self.batch_size).to(self.device))

        if self.T is None:
            d_loss_cls = self.lambda_cls_d * F.cross_entropy(out_cls, y_noisy)
        else:
            eps = 1e-8
            p = F.softmax(out_cls, dim=1)
            d_loss_cls = -self.lambda_cls_d * (torch.sum(
                util.one_hot(y_noisy, netD.num_classes) *
                torch.log(p.mm(self.T) + eps)) / y_noisy.size(0))

        if self.lambda_ct > 0:
            out_dis_, _, out_feat_ = netD(x_real)
            d_loss_ct = self.lambda_ct * (out_dis - out_dis_)**2
            d_loss_ct = d_loss_ct + (self.lambda_ct * 0.1 *
                                     ((out_feat - out_feat_)**2).mean(1))
            d_loss_ct = torch.max(d_loss_ct - self.factor_m,
                                  0.0 * (d_loss_ct - self.factor_m))
            d_loss_ct = d_loss_ct.mean()

        z = netG.sample_z(self.batch_size).to(self.device)
        y_clean = self.noise2clean(y_noisy)
        x_fake = netG(z, y_clean)
        out_dis, _, _ = netD(x_fake.detach())

        if self.gan_loss == 'wgan':
            d_loss_fake = out_dis.mean()
        elif self.gan_loss == 'hinge':
            d_loss_fake = (F.relu(1. + out_dis)).mean()
        else:
            d_loss_fake = F.binary_cross_entropy_with_logits(
                out_dis, torch.zeros(self.batch_size).to(self.device))

        if self.lambda_gp > 0:
            d_loss_gp = self.lambda_gp * self.gradient_penalty(x_real, x_fake, netD)

        d_loss = 0
        d_loss = d_loss + d_loss_real + d_loss_fake
        d_loss = d_loss + d_loss_cls
        if self.lambda_gp > 0:
            d_loss = d_loss + d_loss_gp
        if self.lambda_ct > 0:
            d_loss = d_loss + d_loss_ct

        netD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        if self.gan_loss == 'wgan' or self.gan_loss == 'hinge':
            self.loss['D/loss_adv'] = d_loss_real.item() + d_loss_fake.item()
        else:
            self.loss['D/loss_real'] = d_loss_real.item()
            self.loss['D/loss_fake'] = d_loss_fake.item()
        self.loss['D/loss_cls'] = d_loss_cls.item()
        if self.lambda_gp > 0:
            self.loss['D/loss_gp'] = d_loss_gp.item()
        if self.lambda_ct > 0:
            self.loss['D/loss_ct'] = d_loss_ct.item()
        self.loss['D/loss'] = d_loss.item()

    for lr_scheduler in self.lr_schedulers:
        lr_scheduler.step()

    self.label_noisy = label_noisy
    self.iteration += 1
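# A standalone sketch of the noisy-label classification term used above when a noise
# transition matrix T is available (forward correction): the clean-class posterior
# p = softmax(logits) is mixed through T before taking the log-likelihood of the noisy
# labels. This mirrors the d_loss_cls branch in update(); F.one_hot is used here as a
# local stand-in for util.one_hot, and the function name is illustrative.
import torch
import torch.nn.functional as F

def forward_corrected_ce(logits, y_noisy, T, eps=1e-8):
    p = F.softmax(logits, dim=1)                                    # p(clean class | x)
    p_noisy = p.mm(T)                                               # p(noisy class | x) = p @ T
    y_onehot = F.one_hot(y_noisy, num_classes=logits.size(1)).float()
    return -(y_onehot * torch.log(p_noisy + eps)).sum() / y_noisy.size(0)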
def compute_beat(beat, notes_in_bar):
    return one_hot(beat % notes_in_bar, notes_in_bar)
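# Illustrative note: assuming one_hot(index, size) returns a length-`size` vector with a
# 1 at `index` (the convention compute_beat and load_all both appear to use), compute_beat
# tags a time step with its position inside the bar, e.g. with notes_in_bar = 4:
# compute_beat(6, 4) -> position 6 % 4 == 2 -> [0, 0, 1, 0]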
def train(epoch, loader, model, optimizer, scheduler, device, img_size):
    loader = tqdm(loader)
    model.train()

    criterion = nn.BCELoss(reduction='none')
    criterion.to(device)

    latent_loss_weight = 0.25
    sample_size = 8

    bce_sum = 0
    bce_n = 0

    for ep_idx, (episode, _) in enumerate(loader):
        for i in range(episode.shape[1]):
            model.zero_grad()

            # Get, resize and one-hot encode current batch of images
            img = episode[:, i]
            img = util.resize_img(img, type='seg', size=img_size)
            img = img.to(device)
            #bce_weight = util.loss_weights(img).to(device)
            bce_weight = util.seg_weights(img, out_channel=13).to(device)
            img = util.one_hot(img, device=device)

            out, latent_loss, _ = model(img)
            recon_loss = criterion(out, img)
            recon_loss = (recon_loss * bce_weight).mean()
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss
            loss.backward()

            if scheduler is not None:
                scheduler.step()
            optimizer.step()

            recon_loss_item = recon_loss.item()
            latent_loss_item = latent_loss.item()

            bce_sum += recon_loss.item() * img.shape[0]
            bce_n += img.shape[0]

            lr = optimizer.param_groups[0]['lr']

            loader.set_description((
                f'epoch: {epoch + 1}; bce: {recon_loss_item:.5f}; '
                f'latent: {latent_loss_item:.3f}; avg bce: {bce_sum/bce_n:.5f}; '
                f'lr: {lr:.5f}; '))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad():
                    out, _, _ = model(sample)

                #print('id[0]: ', id[0])
                # Convert one-hot semantic segmentation to RGB
                sample = util.seg_to_rgb(sample)
                out = util.seg_to_rgb(out)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f'sample/seg/{str(epoch + 1).zfill(4)}_{img_size}x{img_size}.png',
                    nrow=sample_size,
                )

                model.train()
def predict(loader, model_rnn, model_vqvae, args):
    model_rnn.eval()
    model_vqvae.eval()

    start = 75 * 3  # 0 for original

    with torch.no_grad():
        for ep_idx, (episode, _) in enumerate(loader):
            sequence_top = None
            for i in range(start, (args.in_steps + args.steps) * 3 + start, 3):
                img = episode[:, i].to(args.device)
                img = util.resize_img(img, type=args.img_type, size=args.img_size)
                if args.img_type == 'seg':
                    img = util.one_hot(img, device=args.device)
                if i == start:
                    seq_in = img
                else:
                    seq_in = torch.cat((seq_in, img))
                _, _, top = model_vqvae.encode(img.to(args.device))
                if i == start:
                    sequence_top = top.reshape(1, img.shape[0], -1)
                else:
                    sequence_top = torch.cat(
                        (sequence_top, top.reshape(1, img.shape[0], -1)), 0)

            seq_len = sequence_top.shape[0]
            inputs_top = sequence_top.long()
            inputs_top = F.one_hot(inputs_top, num_classes=args.n_embed).float()
            inputs_top = inputs_top.view(-1, img.shape[0], args.n_embed * 64)  # 16

            # Forward pass through the rnn
            out, hidden = model_rnn(inputs_top[:args.in_steps])

            # Reshape, argmax and prepare for decoding image
            out = out[-1].unsqueeze(0)
            out_top = out.view(-1, args.n_embed, 64)  # 16
            out_top = torch.argmax(out_top, dim=1)
            out_top_seq = out_top.view(-1, 8, 8)  # 4,4
            out = out_top

            for t in range(args.steps - 1):
                # One-hot encode previous prediction
                out = out.long()
                out = F.one_hot(out, num_classes=args.n_embed).float()
                out = out.view(-1, img.shape[0], args.n_embed * 64)  # 16

                # Predict next frame
                out, hidden = model_rnn(out, hidden=hidden)

                # Argmax and save
                out = out.view(-1, args.n_embed, 64)  # 16
                out = torch.argmax(out, dim=1)
                out_top_seq = torch.cat((out_top_seq, out.view(-1, 8, 8)), 0)  # 4,4

            decoded_samples = model_vqvae.decode_code(out_top_seq)  # old vqvae

            channels = 13 if args.img_type == 'seg' else 3
            seq_out = torch.zeros(args.in_steps, channels, args.img_size,
                                  args.img_size).to(args.device)
            #print('seq_out: ', seq_out.shape, 'decoded_samples: ', decoded_samples[0].shape)
            seq_out = torch.cat((seq_out, decoded_samples), 0)

            sequence = torch.cat((seq_in.to(args.device), seq_out.to(args.device)))
            sequence_rgb = util.seg_to_rgb(sequence) if args.img_type == 'seg' else sequence

            ########## save images and measure IoU ############
            if args.save_images is True:
                save_individual_images(sequence_rgb, ep_idx, args)

            utils.save_image(
                sequence_rgb,
                f'predictions/test_pred_{ep_idx}.png',
                nrow=(args.in_steps + args.steps),
            )