def agent(agent_id, all_cooked_time, all_cooked_bw, net_params_queue, exp_queue, model_type):
    torch.set_num_threads(1)

    net_env = env.Environment(all_cooked_time=all_cooked_time,
                              all_cooked_bw=all_cooked_bw,
                              random_seed=agent_id)

    with open(LOG_FILE + '_agent_' + str(agent_id), 'w') as log_file:
        net = A3C(NO_CENTRAL, model_type, [S_INFO, S_LEN], A_DIM,
                  ACTOR_LR_RATE, CRITIC_LR_RATE)

        # initial synchronization of the network parameters from the coordinator
        time_stamp = 0
        for epoch in range(TOTALEPOCH):
            actor_net_params = net_params_queue.get()
            net.hardUpdateActorNetwork(actor_net_params)
            last_bit_rate = DEFAULT_QUALITY
            bit_rate = DEFAULT_QUALITY
            s_batch = []
            a_batch = []
            r_batch = []
            entropy_record = []
            state = torch.zeros((1, S_INFO, S_LEN))

            # the action is from the last decision
            # this is to make the framework similar to the real
            delay, sleep_time, buffer_size, rebuf, \
                video_chunk_size, next_video_chunk_sizes, \
                end_of_video, video_chunk_remain = \
                net_env.get_video_chunk(bit_rate)

            time_stamp += delay  # in ms
            time_stamp += sleep_time  # in ms

            while not end_of_video and len(s_batch) < TRAIN_SEQ_LEN:
                last_bit_rate = bit_rate

                state = state.clone().detach()
                state = torch.roll(state, -1, dims=-1)

                state[0, 0, -1] = VIDEO_BIT_RATE[bit_rate] / float(np.max(VIDEO_BIT_RATE))  # last quality
                state[0, 1, -1] = buffer_size / BUFFER_NORM_FACTOR  # 10 sec
                state[0, 2, -1] = float(video_chunk_size) / float(delay) / M_IN_K  # kilo byte / ms
                state[0, 3, -1] = float(delay) / M_IN_K / BUFFER_NORM_FACTOR  # 10 sec
                state[0, 4, :A_DIM] = torch.tensor(next_video_chunk_sizes) / M_IN_K / M_IN_K  # mega byte
                state[0, 5, -1] = min(video_chunk_remain, CHUNK_TIL_VIDEO_END_CAP) / float(CHUNK_TIL_VIDEO_END_CAP)

                bit_rate = net.actionSelect(state)
                # Note: we need to discretize the probability into 1/RAND_RANGE steps,
                # because there is an intrinsic discrepancy in passing single state and batch states

                delay, sleep_time, buffer_size, rebuf, \
                    video_chunk_size, next_video_chunk_sizes, \
                    end_of_video, video_chunk_remain = \
                    net_env.get_video_chunk(bit_rate)

                reward = VIDEO_BIT_RATE[bit_rate] / M_IN_K \
                    - REBUF_PENALTY * rebuf \
                    - SMOOTH_PENALTY * np.abs(VIDEO_BIT_RATE[bit_rate] -
                                              VIDEO_BIT_RATE[last_bit_rate]) / M_IN_K

                s_batch.append(state)
                a_batch.append(bit_rate)
                r_batch.append(reward)
                entropy_record.append(3)

                # log time_stamp, bit_rate, buffer_size, reward
                log_file.write(str(time_stamp) + '\t' +
                               str(VIDEO_BIT_RATE[bit_rate]) + '\t' +
                               str(buffer_size) + '\t' +
                               str(rebuf) + '\t' +
                               str(video_chunk_size) + '\t' +
                               str(delay) + '\t' +
                               str(reward) + '\n')
                log_file.flush()

            exp_queue.put([
                s_batch,  # ignore the first chunk
                a_batch,  # since we don't have
                r_batch,  # control over it
                end_of_video,
                {'entropy': entropy_record}
            ])

            log_file.write('\n')  # so that in the log we know where the video ends
def shift_down(tensor):
    shifted = torch.roll(tensor, 1, 0)
    shifted[0, :] = 0.0
    return shifted
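A minimal demo (added here for illustration, not part of the original sources) of the zero-fill pattern used by shift_down and the sibling shift helpers below: torch.roll is circular, so the row that wraps around is overwritten with zeros to turn the circular shift into a plain zero-padded shift.

import torch

t = torch.arange(9.0).view(3, 3)
rolled = torch.roll(t, 1, 0)   # circular: the last row [6., 7., 8.] wraps to the top
shifted = rolled.clone()
shifted[0, :] = 0.0            # zero the wrapped row -> a zero-padded downward shift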
def predict(self, test_df: pd.DataFrame, verbose: bool = False) -> pd.DataFrame:
    self.prev_group_test_df = test_df.copy()
    df = test_df.groupby("user_id").apply(self.aggregate)
    user_ids = df.index.to_list()

    # (N, seq)
    content_id_tensor = pad_sequence(df["content_id"].to_list(), batch_first=True)
    row_id_tensor = pad_sequence(df["row_id"].to_list(), batch_first=True)
    # (N, seq, dim)
    feature_tensor = pad_sequence(df["feature"].to_list(), batch_first=True)

    batch_size, seq_len, dim = feature_tensor.shape
    seq_len_mask = (content_id_tensor != 0).to(dtype=torch.uint8)
    is_question_mask = pad_sequence(df["is_question_mask"].to_list(), batch_first=True)
    initial_state = self.get_state(user_ids)

    if "y" in df:
        y = pad_sequence(df["y"].to_list(), batch_first=True)
        # shift the labels one step right so each position only sees whether
        # the previous question was answered correctly
        y = torch.roll(y, 1, dims=1)
        y[:, 0] = -1
        y = torch.unsqueeze(y, dim=2).float()
    else:
        y = None

    with torch.no_grad():
        if verbose:
            print("feature_tensor", feature_tensor.shape)
            print("content_id_tensor", content_id_tensor.shape)
        pred_logit, states = self.model(content_id=content_id_tensor,
                                        bundle_id=None,
                                        feature=feature_tensor,
                                        user_id=None,
                                        mask=seq_len_mask,
                                        initial_state=initial_state,
                                        ans_prev_correctly=y)

    if y is None:
        self.update_state(user_ids=user_ids, states=states)

    flatten_is_question_mask = is_question_mask.view(batch_size * seq_len)
    flatten_seq_mask = seq_len_mask.view(batch_size * seq_len).bool()
    flatten_pred = torch.sigmoid(pred_logit.view(batch_size * seq_len))
    flatten_row_id = row_id_tensor.view(batch_size * seq_len)

    flatten_pred = flatten_pred[flatten_seq_mask & flatten_is_question_mask].cpu().data.numpy()
    flatten_row_id = flatten_row_id[flatten_seq_mask & flatten_is_question_mask].cpu().data.numpy()

    pred_df = pd.DataFrame({
        "row_id": flatten_row_id,
        "answered_correctly": flatten_pred
    })
    return pred_df
def run_projection(
    ctx: click.Context,
    network_pkl: str,
    initial_learning_rate: float,
    w_avg_samples: int = 10000,
    initial_noise_factor: float = 0.05,
    regularize_noise_weight: float = 1e5,
):
    """
    Project given image to the latent space of pretrained network pickle.
    Adapted from stylegan3-fun/projector.py
    """
    torch.manual_seed(42)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)

    # Open Spout stream for target images
    spout = Spout(silent=False, width=1044, height=1088)  # TODO set W and H to 720p?
    spout.createReceiver('input')
    spout.createSender('output')

    # Stabilize the latent space to make things easier (for StyleGAN3's config t and r models)
    gen_utils.anchor_latent_space(G)

    # == Adapted from project() in stylegan3-fun/projector.py == #
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)

    # Compute w stats.
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    print('Projecting in W+ latent space...')
    w_avg = torch.mean(w_samples, dim=0, keepdim=True)  # [1, L, C]
    w_std = (torch.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs (only for StyleGAN2 models)
    noise_buffs = {
        name: buf
        for (name, buf) in G.synthesis.named_buffers()
        if 'noise_const' in name
    }

    w_noise_scale = w_std * initial_noise_factor  # noise scale is constant
    lr = initial_learning_rate  # learning rate is constant

    # Load the VGG16 feature detector.
    url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(url, device=device)

    w_opt = w_avg.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([w_opt] + list(noise_buffs.values()),
                                 betas=(0.9, 0.999),
                                 lr=initial_learning_rate)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Init noise.
    for buf in noise_buffs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True
    # == End setup from project() == #

    # Project continuously
    while True:
        # check on close window
        spout.check()
        # receive data
        data = spout.receive()
        # data = align_face(data, G.img_resolution)  ## TOO SLOW !!!
        # print(data.shape)

        # Features for target image. Reshape to 256x256 if it's larger, to use with VGG16
        # target = np.array(target, dtype=np.uint8)
        target = torch.tensor(data.transpose([2, 0, 1]), device=device)
        target = target.unsqueeze(0).to(device).to(torch.float32)
        if target.shape[2] > 256:
            target = F.interpolate(target, size=(256, 256), mode='area')
        target_features = vgg16(target, resize_images=False, return_lpips=True)

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = w_opt + w_noise
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_buffs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Print on the same line (avoid cluttering the command line)
        message = f'dist {dist:.7e} | loss {loss.item():.7e}'
        print(message, end='\r')

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_buffs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

        # Produce image
        data = gen_utils.w_to_img(G, dlatents=w_opt.detach()[0], noise_mode='const')[0]
        spout.send(data)
def shift_left(tensor):
    shifted = torch.roll(tensor, -1, 1)
    shifted[:, IMAGE_WIDTH - 1] = 0.0
    return shifted
def __call__(self, f):
    for i in range(self.lattice.Q):
        f[i] = torch.roll(f[i],
                          shifts=tuple(self.lattice.stencil.e[i]),
                          dims=tuple(np.arange(self.lattice.D)))
    return f
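An illustrative toy of the streaming step above, with made-up stencil values (the real lattice object supplies Q, D, and the velocity set e): each population f[i] is rolled along every spatial dimension by its discrete velocity e[i].

import numpy as np
import torch

e = np.array([[1, 0], [0, 1], [-1, 0]])   # hypothetical discrete velocities (Q=3, D=2)
f = torch.randn(3, 4, 4)                  # Q populations on a 4x4 grid
for i in range(len(e)):
    # roll each population by its velocity, with periodic (wrap-around) boundaries
    f[i] = torch.roll(f[i], shifts=tuple(int(s) for s in e[i]), dims=(0, 1))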
def scalar_last2first(X):
    return torch.roll(X, 1, -1)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
    height, width = self.input_resolution
    batch_size, dim, channels = hidden_states.size()
    shortcut = hidden_states

    hidden_states = self.layernorm_before(hidden_states)
    hidden_states = hidden_states.view(batch_size, height, width, channels)

    # cyclic shift
    if self.shift_size > 0:
        shifted_hidden_states = torch.roll(hidden_states,
                                           shifts=(-self.shift_size, -self.shift_size),
                                           dims=(1, 2))
    else:
        shifted_hidden_states = hidden_states

    # partition windows
    hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
    hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)

    if self.attn_mask is not None:
        self.attn_mask = self.attn_mask.to(hidden_states_windows.device)

    self_attention_outputs = self.attention(
        hidden_states_windows,
        self.attn_mask,
        head_mask,
        output_attentions=output_attentions,
    )

    attention_output = self_attention_outputs[0]
    outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

    attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
    shifted_windows = window_reverse(attention_windows, self.window_size, height, width)  # B H' W' C

    # reverse cyclic shift
    if self.shift_size > 0:
        attention_windows = torch.roll(shifted_windows,
                                       shifts=(self.shift_size, self.shift_size),
                                       dims=(1, 2))
    else:
        attention_windows = shifted_windows

    attention_windows = attention_windows.view(batch_size, height * width, channels)

    hidden_states = shortcut + self.drop_path(attention_windows)

    layer_output = self.layernorm_after(hidden_states)
    layer_output = self.intermediate(layer_output)
    layer_output = hidden_states + self.output(layer_output)

    outputs = (layer_output,) + outputs
    return outputs
def project(
    G,
    target: torch.Tensor,  # [C,H,W], dynamic range [0,255]; W & H must match G output resolution
    *,
    num_steps=1000,
    w_avg_samples=10000,
    initial_learning_rate=0.1,
    initial_noise_factor=0.05,
    lr_rampdown_length=0.25,
    lr_rampup_length=0.05,
    noise_ramp_length=0.75,
    regularize_noise_weight=1e5,
    verbose=False,
    device: torch.device
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)  # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = {
        name: buf
        for (name, buf) in G.synthesis.named_buffers()
        if 'noise_const' in name
    }

    # Load VGG16 feature detector.
    # url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    # with dnnlib.util.open_url(url) as f:
    #     vgg16 = torch.jit.load(f).eval().to(device)
    vgg16 = torch.jit.load('vgg16.pt').eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(torch.from_numpy(w_avg).repeat_interleave(repeats=G.mapping.num_ws, dim=1),
                         dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()),
                                 betas=(0.9, 0.999), lr=initial_learning_rate)

    distances = []

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        distances.append(dist.item())
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out, distances
def add_speed_perturb(self, targets, targ_lens):
    """Adds speed perturbation and random_shift to the input signals"""
    min_len = -1
    recombine = False

    if self.hparams.use_speedperturb:
        # Performing speed change (independently on each source)
        new_targets = []
        recombine = True
        for i in range(targets.shape[-1]):
            new_target = self.hparams.speedperturb(targets[:, :, i], targ_lens)
            new_targets.append(new_target)
            if i == 0:
                min_len = new_target.shape[-1]
            else:
                if new_target.shape[-1] < min_len:
                    min_len = new_target.shape[-1]

        if self.hparams.use_rand_shift:
            # Performing random_shift (independently on each source)
            recombine = True
            for i in range(targets.shape[-1]):
                rand_shift = torch.randint(self.hparams.min_shift,
                                           self.hparams.max_shift, (1,))
                new_targets[i] = new_targets[i].to(self.device)
                new_targets[i] = torch.roll(new_targets[i],
                                            shifts=(rand_shift[0],), dims=1)

    # Re-combination
    if recombine:
        if self.hparams.use_speedperturb:
            targets = torch.zeros(
                targets.shape[0],
                min_len,
                targets.shape[-1],
                device=targets.device,
                dtype=torch.float,
            )
        for i, new_target in enumerate(new_targets):
            targets[:, :, i] = new_targets[i][:, 0:min_len]

    # this applies the same speed perturb to each source
    if self.hparams.use_speedperturb_sameforeachsource:
        targets = targets.permute(0, 2, 1)
        targets = targets.reshape(-1, targets.shape[-1])
        wav_lens = torch.tensor([targets.shape[-1]] * targets.shape[0]).to(self.device)
        targets = self.hparams.speedperturb(targets, wav_lens)
        targets = targets.reshape(-1, self.hparams.num_spks, targets.shape[-1])
        targets = targets.permute(0, 2, 1)

    mix = targets.sum(-1)
    return mix, targets
def step(self, action, last_prediction, time_to_guide):
    # action: log to linear
    if time_to_guide == True:
        bandwidth_prediction = last_prediction * pow(2, (2 * action - 1))
        self.gcc_estimator.change_bandwidth_estimation(bandwidth_prediction)
    else:
        bandwidth_prediction = last_prediction
        # bandwidth_prediction = action

    # run the action, get related packet list:
    packet_list, done = self.gym_env.step(bandwidth_prediction)
    for pkt in packet_list:
        packet_info = PacketInfo()
        packet_info.payload_type = pkt["payload_type"]
        packet_info.ssrc = pkt["ssrc"]
        packet_info.sequence_number = pkt["sequence_number"]
        packet_info.send_timestamp = pkt["send_time_ms"]
        packet_info.receive_timestamp = pkt["arrival_time_ms"]
        packet_info.padding_length = pkt["padding_length"]
        packet_info.header_length = pkt["header_length"]
        packet_info.payload_size = pkt["payload_size"]
        packet_info.bandwidth_prediction = bandwidth_prediction
        self.packet_record.on_receive(packet_info)
        self.gcc_estimator.report_states(pkt)

    # calculate state:
    self.receiving_rate = self.packet_record.calculate_receiving_rate(interval=self.step_time)  # todo
    self.receiving_rate_list.append(self.receiving_rate)
    # states.append(liner_to_log(receiving_rate))
    # self.receiving_rate.append(receiving_rate)
    # np.delete(self.receiving_rate, 0, axis=0)
    self.delay = self.packet_record.calculate_average_delay(interval=self.step_time)
    self.delay_list.append(self.delay)
    # states.append(min(delay/1000, 1))
    # self.delay.append(delay)
    # np.delete(self.delay, 0, axis=0)
    self.loss_ratio = self.packet_record.calculate_loss_ratio(interval=self.step_time)
    self.loss_ratio_list.append(self.loss_ratio)
    self.bandwidth_prediction = bandwidth_prediction
    self.bandwidth_prediction_list.append(bandwidth_prediction)
    self.gcc_decision = self.gcc_estimator.get_estimated_bandwidth()

    self.state = self.state.clone().detach()
    self.state = torch.roll(self.state, -1, dims=-1)
    # states.append(loss_ratio)
    # self.loss_ratio.append(loss_ratio)
    # np.delete(self.loss_ratio, 0, axis=0)
    # latest_prediction = self.packet_record.calculate_latest_prediction()
    # states.append(liner_to_log(latest_prediction))
    # self.prediction_history.append(latest_prediction)
    # np.delete(self.prediction_history, 0, axis=0)
    # states = np.vstack((self.receiving_rate, self.delay, self.loss_ratio, self.prediction_history))

    # todo: regularization needs to be fixed
    self.state[0, 0, -1] = self.receiving_rate / 300000.0
    self.state[0, 1, -1] = self.delay / 1000.0
    self.state[0, 2, -1] = self.loss_ratio
    self.state[0, 3, -1] = self.bandwidth_prediction / 300000.0

    # maintain list length
    if len(self.receiving_rate_list) == self.config['state_length']:
        self.receiving_rate_list.pop(0)
        self.delay_list.pop(0)
        self.loss_ratio_list.pop(0)

    # calculate reward:
    reward = self.get_reward()

    return self.state, reward, done, self.gcc_decision
def main():
    plt.ion()  # enable interactive mode for continuous plotting
    opt = TrainOptions().parse()
    # device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")
    # define the discriminator and generator networks
    # net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)  # batchnorm
    net_d = Discriminator(opt.output_nc)
    init_weights(net_d)
    # net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    net_g = networks.define_G(1, 65, opt.ngf, 'CTnet', opt.norm, not opt.no_dropout,
                              opt.init_type, opt.init_gain, opt.gpu_ids)
    init_weights(net_g)
    net_d.to(device)
    net_g.to(device)

    if load_net:
        # save_filename = 'net_d%s.pth' % epoch_start
        # save_path = os.path.join('./check/', save_filename)
        # load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g, save_path)

    # loss functions
    # criterion = nn.BCELoss().to(device)
    criterion = nn.MSELoss().to(device)
    criterion1 = nn.L1Loss().to(device)

    # labels for real and fake data
    true_lable = Variable(torch.ones(BATCH_SIZE)).to(device)
    fake_lable = Variable(torch.zeros(BATCH_SIZE)).to(device)

    # optimizers
    # optimizer_d = torch.optim.Adam(net_d.parameters(), lr=0.0008, betas=[0.3, 0.9])
    optimizer_g = torch.optim.Adam(net_g.parameters(), lr=0.001, betas=[0.9, 0.9])
    # optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    # optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    # one = torch.FloatTensor([1]).cuda()
    # mone = one * -1

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset

    for epoch in range(MAX_EPOCH):
        if epoch == 900:
            at = 1
        # add noise to the real data
        for ii, data in enumerate(dataset):
            g_noises = data['A'].cuda()
            g_noises = g_noises.squeeze(0)
            real_data = data['B'].cuda()
            real_data = real_data.squeeze(0)

            optimizer_g.zero_grad()
            fake_date = net_g(g_noises)
            # aa = torch.roll(torch.roll(real_data, 32, 2), 32, 3)
            # a1 = aa[0, 32, :, :].data.cpu().numpy()
            # bb = torch.roll(torch.roll(aa.flip(2).flip(3), 1, 2), 1, 3)
            # b1 = bb[0, 32, :, :].data.cpu().numpy()
            # c1 = b1 - a1
            # loss = criterion(torch.roll(torch.roll(fake_date, 32, 2), 32, 3),
            #                  torch.roll(torch.roll(real_data, 32, 2), 32, 3)) \
            #        + criterion(torch.roll(torch.roll(torch.roll(torch.roll(fake_date, 32, 2), 32, 3).flip(2).flip(3), 1, 2), 1, 3),
            #                    torch.roll(torch.roll(real_data, 32, 2), 32, 3))
            # if ii % 2 == 0:
            #     # loss = criterion(torch.roll(torch.roll(torch.roll(torch.roll(fake_date, 32, 2), 32, 3).flip(2).flip(3), 1, 2), 1, 3),
            #     #                  torch.roll(torch.roll(real_data, 32, 2), 32, 3))
            #     loss = criterion(torch.roll(torch.roll(fake_date, 32, 2), 32, 3).flip(2).flip(3),
            #                      torch.roll(torch.roll(real_data, 32, 2), 32, 3))
            # else:
            #     loss = criterion(torch.roll(torch.roll(fake_date, 32, 2), 32, 3),
            #                      torch.roll(torch.roll(real_data, 32, 2), 32, 3))
            # loss1 = criterion(torch.roll(torch.roll(fake_date, 32, 2), 32, 3),
            #                   torch.roll(torch.roll(real_data, 32, 2), 32, 3))
            # p1 = fake_date[:, :, 0:32, 0:65]
            # p2 = fake_date[:, :, 32, 0:32].unsqueeze(2)
            # p2 = torch.cat((p2, fake_date[:, :, 32, 32].unsqueeze(2).unsqueeze(3)), 3)
            # p2 = torch.cat((p2, fake_date[:, :, 32, 0:32].unsqueeze(2).flip(3)), 3)
            # P = torch.cat((p1, p2), 2)
            # P = torch.cat((P, p1.flip(2).flip(3)), 2)
            # loss1 = criterion(P, real_data)
            loss1 = criterion(fake_date, real_data)
            # loss2 = criterion(torch.abs(torch.roll(torch.roll(fake_date, 32, 2), 32, 3)
            #                             - torch.roll(torch.roll(torch.roll(torch.roll(fake_date, 32, 2), 32, 3).flip(2).flip(3), 1, 2), 1, 3)),
            #                   torch.zeros(fake_date.size()).cuda())
            loss2 = criterion1(fake_date, fake_date.flip(2).flip(3))  # does not yet have the desired effect
            loss = loss1
            # loss1.backward(retain_graph=True)
            loss.backward()
            optimizer_g.step()

            # real_data = np.vstack([POINT * POINT + np.random.normal(0, 0.01, SAMPLE_NUM) for _ in range(BATCH_SIZE)])
            # real_data = np.vstack([np.sin(POINT) + np.random.normal(0, 0.01, SAMPLE_NUM) for _ in range(BATCH_SIZE)])
            # real_data = Variable(torch.Tensor(real_data)).to(device)
            # use random noise as the generator input
            # g_noises = np.random.randn(BATCH_SIZE, N_GNET)
            # g_noises = Variable(torch.Tensor(g_noises)).to(device)
            #
            # train the discriminator
            # for p in net_d.parameters():  # reset requires_grad
            #     p.requires_grad = True  # they are set to False below in netG update
            # optimizer_d.zero_grad()
            # discriminator loss on real images
            # d_real = net_d(real_data)
            # loss_d_real = criterion(d_real, true_lable)
            # loss_d_real = -d_real.mean()
            # loss_d_real.backward()
            # discriminator loss on fake images
            # fake_date = net_g(g_noises)
            # d_fake = net_d(fake_date.detach())
            # loss_d_fake = criterion(d_fake, fake_lable)
            # loss_d_fake = d_fake.mean()
            # loss_d_fake.backward()
            #
            # train with gradient penalty
            # gradient_penalty = calc_gradient_penalty(net_d, real_data, fake_date)
            # gradient_penalty.backward()
            #
            # D_cost = loss_d_fake + loss_d_real + gradient_penalty
            # D_cost.backward()
            # Wasserstein_D = loss_d_real - loss_d_fake
            # optimizer_d.step()
            # if ii % CRITIC_ITERS == 0:
            #     # train the generator
            #     # for p in net_d.parameters():
            #     #     p.requires_grad = False  # to avoid computation
            #     optimizer_g.zero_grad()
            #     fake_date = net_g(g_noises)
            #     d_fake = net_d(fake_date)
            #     # generator loss on generated (fake) images
            #     # loss_g = criterion(d_fake, true_lable)
            #     loss_g = -d_fake.mean()
            #     loss_g.backward()
            #     optimizer_g.step()
            #     G_cost = -loss_g
            # for name, parms in net_g.named_parameters():
            #     if name == 'model.2.weight':
            #         print('layer:', name, parms.size(), '-->name:', name,
            #               '-->grad_requirs:', parms.requires_grad,
            #               ' -->grad_value:', parms.grad[0])
            #         a = 1
            # # for name, parms in self.netG_A.named_parameters():

            # plot the generated images and related data every 200 steps
            if ii % 10 == 0:
                # print(fake_date[0])
                plt.ion()
                # plt.cla()
                # plt.plot(POINT, fake_date[0].to('cpu').detach().numpy(), c='#4AD631', lw=2,
                #          label="generated line")  # data produced by the generator
                # plt.plot(POINT, real_data[0].to('cpu').detach().numpy(), c='#74BCFF', lw=3,
                #          label="real sin")  # real data
                # prob = (loss_d_real.mean() + 1 - loss_d_fake.mean()) / 2.
                # img_test = []
                fk_im = toimage(torch.irfft(torch.roll(torch.roll(fake_date, -32, 2), -32, 3)
                                            .permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # fk_im = toimage(torch.irfft(fake_date.permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # fk_im = toimage(fake_date[0, 32, :, :].unsqueeze(0).unsqueeze(0))
                # img_test.append(fk_im)
                save_filenamet = 'fake%s.bmp' % epoch
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)
                rel_im = toimage(torch.irfft(torch.roll(torch.roll(real_data, -32, 2), -32, 3)
                                             .permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                save_image(rel_im, os.path.join('./check/img/', 'Real%s.bmp' % epoch))
                test(net_g)
                message = '(epoch: %d, iters: %d, loss1: %.3f, loss2: %.3f) ' % (epoch, ii, loss1, loss2)
                print(message)
                # rel_im = toimage(torch.irfft(real_data.squeeze(0).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                # for it in range(1, 3):
                #     plt.subplot(1, 2, it)
                #     plt.imshow(img_test[it - 1])
                # plt.text(-1, 81, 'D accuracy=%.2f ' % (D_cost.mean()), fontdict={'size': 15})
                # plt.text(-1, 85, 'G accuracy=%.2f ' % (G_cost), fontdict={'size': 15})
                # plt.text(-1, 89, 'W accuracy=%.2f ' % (Wasserstein_D), fontdict={'size': 15})
                # plt.text(-1, 95, 'epoch=%.2f ' % (epoch), fontdict={'size': 15})
                # plt.show()
                # plt.ylim(-2, 2)
                # plt.draw(), plt.pause(0.1), plt.clf()

        # save_filename = 'net_d%s.pth' % epoch
        # save_path = os.path.join('./check/', save_filename)
        # torch.save(net_d.cpu().state_dict(), save_path)
        # net_d.cuda(0)
        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g.cpu().state_dict(), save_path)
        net_g.cuda(0)
def train(model, optimizer, train_iterator, params):
    """Train the model on one epoch

    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        train_iterator: (generator) a generator that generates batches of data and labels
        params: (Params) hyperparameters
    """
    # reload weights from checkpoint if specified
    if args.restore_file:
        restore_path = os.path.join(args.experiment_dir, args.restore_file)
        logging.info("Restoring parameters from {}".format(restore_path))
        if not os.path.exists(restore_path):
            raise FileNotFoundError("File doesn't exist")
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # set model to training mode
    model.train()

    # init some necessary variables
    best_loss = float('inf')
    total_loss = 0.
    steps_count = 0
    num_steps = params.train_size // params.batch_size

    # train for params.num_epochs
    for epoch in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
        for batch in tqdm.tqdm(train_iterator, total=num_steps):
            loss = torch.tensor(0).to(params.device).float()
            output = model(batch.input.to(params.device).float())

            # compute loss to minimize distance between center word and
            # context words based on params.context_window size
            for i in range(1, params.context_window + 1):
                loss += net.cos_embedding_loss(output[i:, :, :], output[:-i, :, :])

            # negative-samples loss
            # negative samples are taken by randomly rolling the batch
            for i in range(params.negative_samples):
                loss += net.cos_embedding_loss(
                    output, torch.roll(output, randrange(params.batch_size), 1), True)

            steps_count += 1
            total_loss += loss.item()

            # save best and latest model for every params.steps_to_save
            if steps_count % params.steps_to_save == 0:
                mean_loss = total_loss / params.steps_to_save
                if mean_loss <= best_loss:
                    best_loss = mean_loss
                    path = os.path.join(args.experiment_dir, 'best_loss.pth.tar')
                    torch.save({
                        'epoch': epoch + 1,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'loss': mean_loss
                    }, path)
                path = os.path.join(args.experiment_dir, 'latest.pth.tar')
                torch.save({
                    'epoch': epoch + 1,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': mean_loss
                }, path)
                logging.info("- Average training loss: {}".format(mean_loss))

                # reset variables
                steps_count = 0
                total_loss = 0

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
def _rotate_image(image, idx):
    width = image.shape[2]
    if idx < 0 or idx >= width:
        return
    rotated = torch.roll(image, idx, 2)
    return rotated
def main():
    plt.ion()  # enable interactive mode for continuous plotting
    opt = TrainOptions().parse()
    # device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")
    # define the discriminator and generator networks
    # net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)  # batchnorm
    # net_d = Discriminator(opt.output_nc)
    net_d_ct = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D,
                                 opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_d_dr = networks.define_D(opt.input_nc, opt.ndf, 'ProjNet', opt.n_layers_D,
                                 opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_g_dr = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                 not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    # net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    net_g_ct = networks.define_G(1, 65, opt.ngf, 'CTnet', opt.norm, not opt.no_dropout,
                                 opt.init_type, opt.init_gain, opt.gpu_ids)
    # init_weights(net_d_dr)
    # init_weights(net_d_ct)
    # init_weights(net_g_dr)
    # init_weights(net_g_ct)
    net_d_ct.to(device)
    net_d_dr.to(device)
    net_g_dr.to(device)
    net_g_ct.to(device)

    one = torch.FloatTensor([1])
    mone = one * -1
    one = one.to(device)
    mone = mone.to(device)
    # summary(net_g_dr, (2, 65, 65, 65))

    if load_net:
        # save_filename = 'net_d%s.pth' % epoch_start
        # save_path = os.path.join('./check/', save_filename)
        # load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g_ct, save_path)

    # loss functions
    # criterion = nn.BCELoss().to(device)
    criterion = nn.MSELoss().to(device)
    criterion1 = nn.L1Loss().to(device)

    # optimizers
    optimizer_d = torch.optim.Adam(itertools.chain(net_d_ct.parameters(), net_d_dr.parameters()),
                                   lr=0.0001, betas=[0.5, 0.9])
    optimizer_g = torch.optim.Adam(itertools.chain(net_g_ct.parameters(), net_g_dr.parameters()),
                                   lr=0.0001, betas=[0.5, 0.9])
    # optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    # optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    # one = torch.FloatTensor([1]).cuda()
    # mone = one * -1

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset

    def gensample():
        for image in enumerate(dataset):
            yield image

    gen = gensample()
    ii = 0

    for epoch in range(MAX_EPOCH):
        # add noise to the real data
        for it, data in enumerate(dataset):
            # load data
            dr_real = autograd.Variable(data['A'].cuda())
            dr_real = dr_real.squeeze(0)
            ct_real = autograd.Variable(data['B'].cuda())
            ct_real = ct_real.squeeze(0)

            # training
            # inner loop
            freeze_params(net_g_ct)
            freeze_params(net_g_dr)
            unfreeze_params(net_d_ct)
            unfreeze_params(net_d_dr)

            ct_fake = autograd.Variable(net_g_ct(dr_real).data)
            dr_fake = autograd.Variable(net_g_dr(ct_real).data)

            optimizer_d.zero_grad()

            loss_dsc_realct = net_d_ct(ct_real).mean()
            # loss_dsc_realct.backward()
            loss_dsc_fakect = net_d_ct(ct_fake.detach()).mean()
            # loss_dsc_fakect.backward()
            gradient_penalty_ct = calc_gradient_penalty(net_d_ct, ct_real, ct_fake)
            # gradient_penalty_ct.backward()
            loss_d_ct = loss_dsc_fakect - loss_dsc_realct + gradient_penalty_ct
            loss_d_ct.backward()
            Wd_ct = loss_dsc_realct - loss_dsc_fakect

            loss_dsc_realdr = net_d_dr(dr_real).mean()
            # loss_dsc_realdr.backward()
            loss_dsc_fakedr = net_d_dr(dr_fake.detach()).mean()
            # loss_dsc_fakedr.backward()
            gradient_penalty_dr = calc_gradient_penalty(net_d_dr, dr_real, dr_fake)
            # gradient_penalty_dr.backward()
            loss_d_dr = loss_dsc_fakedr - loss_dsc_realdr + gradient_penalty_dr
            loss_d_dr.backward()
            Wd_dr = loss_dsc_realdr - loss_dsc_fakedr

            optimizer_d.step()

            if it % CRITIC_ITERS == 0:
                # if True:
                unfreeze_params(net_g_ct)
                freeze_params(net_d_ct)
                unfreeze_params(net_g_dr)
                freeze_params(net_d_dr)

                ct_fake_g = net_g_ct(dr_real)
                dr_fake_g = net_g_dr(ct_real)

                # outer loop ct -> dr
                # optimizer_g.zero_grad()
                loss_out_dr = criterion1(net_g_ct(dr_fake_g), ct_real)
                # loss_out_dr.backward()
                # optimizer_g.step()
                # net_g_ct.load_state_dict(dict_g_ct)

                # outer loop dr -> ct
                # optimizer_g.zero_grad()
                loss_out_ct = criterion1(net_g_dr(ct_fake_g), dr_real)
                # loss_out_ct.backward()
                # optimizer_g.step()

                # inner-loop GAN losses
                loss_g_ct = -net_d_ct(ct_fake_g).mean()
                # loss_g_ct.backward()
                loss_g_dr = -net_d_dr(dr_fake_g).mean()
                # loss_g_dr.backward()

                loss_gan = loss_out_dr + loss_out_ct
                # loss_gan = loss_out_dr + loss_out_ct + loss_g_ct + loss_g_dr
                # loss_gan = criterion(net_g_ct(dr_real), ct_real) + criterion(net_g_dr(ct_real), dr_real)

                optimizer_g.zero_grad()
                loss_gan.backward()
                optimizer_g.step()

            if it % 1 == 0:
                fk_im = toimage(torch.irfft(torch.roll(torch.roll(ct_fake, -32, 2), -32, 3)
                                            .permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # fk_im = toimage(ct_fake[0, 32, :, :].unsqueeze(0))
                # img_test.append(fk_im)
                save_filenamet = 'fakect%s.bmp' % int(epoch / dataset_size)
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)

                rel_im = toimage(torch.irfft(torch.roll(torch.roll(ct_real, -32, 2), -32, 3)
                                             .permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # rel_im = toimage(ct_real[0, 32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                save_image(rel_im, os.path.join('./check/img/', 'Realct%s.bmp' % int(epoch)))

                fake_im = toimage(torch.irfft(torch.roll(torch.roll(dr_fake, -128, 2), -128, 3)
                                              .permute(1, 2, 3, 0), 2, onesided=False))
                # fake_im = toimage(dr_fake.squeeze(0))
                save_image(fake_im, os.path.join('./check/img/', 'fakedr%s.bmp' % int(epoch)))

                ceshi(net_g_ct)
                message = ('(epoch: %d, iters: %d, D_ct: %.3f;[real:%.3f;fake:%.3f], '
                           'G_ct: %.3f, D_dr: %.3f, G_dr: %.3f) '
                           % (int(epoch), ii, loss_d_ct, loss_dsc_realct, loss_dsc_fakect,
                              loss_g_ct, loss_d_dr, loss_g_dr))
                print(message)

        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g_ct.cpu().state_dict(), save_path)
        net_g_ct.cuda(0)
def shape_change(dstrfs, niter=100, span=None, batch_size=8, verbose=0):  # return_shifts=False
    """
    Align all dSTRFs to the global mean by shifting them along the lag axis.
    Compute shape change nonlinearity measure on the aligned dSTRFs.

    Arguments:
        dstrfs: tensor of dSTRFs with shape [time * channel * lag * frequency]
        niter: number of iterations
        span: the maximum shift allowed per iteration [-span, span]
        batch_size: temporal batch size used for computing the shifts

    Returns:
        shape_change: shape change parameter, tensor of shape [channel]
        dstrfs: shift-corrected and centered dSTRFs, returned if return_shifts is True
        shifts: amount of shift used for each time point to achieve the final result,
            returned if return_shifts is True
    """
    tdim, cdim, ldim, fdim = dstrfs.shape
    if cdim > batch_size:
        return torch.cat([
            shape_change(dstrfs[:, k * batch_size:(k + 1) * batch_size],
                         niter=niter, span=span, batch_size=batch_size)
            for k in range(math.ceil(cdim / batch_size))
        ])

    span = max(ldim - 10, min(ldim, 5)) if span is None else span
    dstrfs = torch.nn.functional.pad(dstrfs, (0, 0, span, span))
    max_shift = ldim // 2 + span
    shifts = torch.zeros((tdim, cdim), dtype=int, device=dstrfs.device)

    t1 = time.time()
    for i in range(niter):
        dmean = dstrfs.mean(dim=0)
        shift = torch.zeros((tdim, cdim), dtype=int, device=dstrfs.device)
        for batch in range(math.ceil(tdim / batch_size)):
            batch_ind = slice(batch * batch_size, (batch + 1) * batch_size)
            shift[batch_ind, :] = utils.find_shift(dstrfs[batch_ind], dmean, max_shift)
        for c in range(cdim):
            for k in range(-max_shift, max_shift + 1):
                batch_ind = shift[:, c] == k
                dstrfs[batch_ind, c] = torch.roll(dstrfs[batch_ind, c], k, 1)

        power = dstrfs.abs().mean(dim=[0, 3])
        center = utils.smooth(power, 9).argmax(dim=1)
        for c in range(cdim):
            dstrfs[:, c] = torch.roll(dstrfs[:, c], int(max_shift - center[c]), 1)
            shifts[:, c] += shift[:, c] + max_shift - center[c]

        if (shift == 0).all():
            break
        elif verbose >= 1:
            print('Iteration #%d: %.4f average shift.' % (i + 1, shift.float().mean().cpu()),
                  flush=True)
    t2 = time.time()

    if verbose >= 1:
        if not (shift == 0).all():
            print('Failed to converge to solution ({:s} elapsed).'.format(utils.timestr(t2 - t1)))
        else:
            print('Converged in {:d} iterations ({:s} elapsed).'.format(i, utils.timestr(t2 - t1)))

    power = dstrfs.abs().mean(dim=[0, 3])
    best_shift = torch.arange(-span, span + 1)[torch.argmax(torch.stack([
        torch.roll(power, k, 1)[:, span:-span].sum(dim=1)
        for k in range(-span, span + 1)
    ]), dim=0)]
    for c in range(cdim):
        dstrfs[:, c] = torch.roll(dstrfs[:, c], int(best_shift[c]), 1)
        shifts[:, c] += best_shift[c]
    dstrfs = dstrfs[:, :, span:-span]

    # if return_shifts:
    #     return complexity(dstrfs), dstrfs.cpu(), shifts.cpu()
    # else:
    #     return complexity(dstrfs)
    return complexity(dstrfs)
def shift_right(t: torch.Tensor) -> torch.Tensor:
    st = torch.roll(t, 1, 0)
    st[0] = text_encoder.BOS_ID
    return st
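A usage sketch for shift_right above, with a stand-in BOS id (text_encoder.BOS_ID in the original): the roll moves the last token to the front, and overwriting position 0 turns that into "drop the last token, prepend BOS", the usual decoder-input construction for teacher forcing.

import torch

BOS_ID = 1                          # stand-in for text_encoder.BOS_ID
tokens = torch.tensor([5, 6, 7, 8])
st = torch.roll(tokens, 1, 0)       # [8, 5, 6, 7]: last token wraps to the front
st[0] = BOS_ID                      # [1, 5, 6, 7]: wrapped token replaced by BOS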
def actor_critic_loss(policy, model, dist_class, train_batch):
    assert policy.is_recurrent(), "policy must be recurrent"

    seq_lens = train_batch['seq_lens']
    batch_size = seq_lens.shape[0]
    max_seq_len = torch.max(seq_lens)
    mask_orig = sequence_mask(seq_lens, max_seq_len)
    mask = torch.reshape(mask_orig, [-1])

    horizon = policy.config['fun_horizon']
    manager_horizon_mask = mask_orig.clone()
    manager_horizon_mask[:, -horizon:] = False
    manager_horizon_mask = manager_horizon_mask.reshape(-1)

    _ = model.icm_forward(train_batch[SampleBatch.OBS], train_batch[SampleBatch.NEXT_OBS])
    icm_fwd_loss = model.icm_fwd_forward(train_batch[SampleBatch.ACTIONS])
    icm_inv_loss = model.icm_inv_forward(train_batch[SampleBatch.ACTIONS])
    icm_loss = 0.995 * icm_fwd_loss + 0.005 * icm_inv_loss
    icm_loss = torch.sum(icm_loss * mask)
    icm_loss /= batch_size * max_seq_len
    policy.icm_loss = icm_loss

    # Hacky way of passing data from sample batch to train batch
    model.random_select = train_batch['random_select'].reshape((batch_size, -1))
    model.random_goal = train_batch['random_goal'].reshape((batch_size, max_seq_len, -1))

    logits, _ = model.from_batch(train_batch)
    manager_values, worker_values = model.value_function()
    manager_latent_state, manager_goal = model.manager_features()

    manager_latent_state_future = torch.roll(manager_latent_state, -horizon, 1)
    manager_latent_state_diff = (manager_latent_state_future - manager_latent_state).detach()

    policy.manager_loss = 10.0 * -torch.sum(
        train_batch['manager_advantages'] *
        F.cosine_similarity(manager_latent_state_diff, manager_goal, dim=-1).reshape(-1) *
        manager_horizon_mask) / (batch_size * max_seq_len)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])

    policy.entropy = 3e-4 * -torch.sum(dist.entropy() * mask) / (batch_size * max_seq_len)
    policy.pi_err = 0.1 * -torch.sum(
        train_batch['worker_advantages'] * log_probs.reshape(-1) * mask) / (batch_size * max_seq_len)
    policy.manager_value_err = torch.sum(
        torch.pow((manager_values.reshape(-1) - train_batch['manager_value_targets']) * mask, 2.0)
    ) / (batch_size * max_seq_len)
    policy.worker_value_err = 0.01 * torch.sum(
        torch.pow((worker_values.reshape(-1) - train_batch['worker_value_targets']) * mask, 2.0)
    ) / (batch_size * max_seq_len)

    overall_err = sum([
        policy.pi_err,
        policy.manager_value_err,
        policy.worker_value_err,
        policy.entropy,
        policy.manager_loss,
        policy.icm_loss,
    ])
    return overall_err
def scalar_first2last(X):
    return torch.roll(X, -1, -1)
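A quick sanity check (illustrative, not from the original sources) that scalar_first2last and scalar_last2first above are inverses of each other: rolling the last axis by -1 and then by +1 restores the original layout.

import torch

X = torch.arange(8).view(2, 4)
assert torch.equal(torch.roll(torch.roll(X, -1, -1), 1, -1), X)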
def forward(self, x1, x2):
    # positive samples: aligned (x1, x2) pairs
    pos = self.net(torch.cat([x1, x2], 1))
    # negative samples: roll x1 by one along the batch dimension to break the pairing
    neg = self.net(torch.cat([torch.roll(x1, 1, 0), x2], 1))
    return -softplus(-pos).mean() - softplus(neg).mean(), \
        pos.mean() - neg.exp().mean() + 1
def diff(x):
    shift_x = torch.roll(x, 1, 2)
    return ((shift_x - x) + 1) / 2
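One caveat worth noting (an observation added here, not a claim about the original intent): because diff shifts with torch.roll, the first position along dim 2 is differenced against the last one, so the boundary wraps around instead of being padded.

import torch

x = torch.arange(4.0).view(1, 1, 4)   # [[[0., 1., 2., 3.]]]
shift_x = torch.roll(x, 1, 2)         # [[[3., 0., 1., 2.]]]: wrap-around at the boundary
print(((shift_x - x) + 1) / 2)        # first entry mixes x[..., -1] with x[..., 0]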
def train(self, num_episodes=300, use_prev=None, save_path=None):
    """
    :param num_episodes: number of episodes to train
    :param use_prev: use previous weights to start training from
    :param save_path: where to save the checkpoint at the end of training
    """
    episode_durations = []
    if use_prev is not None:
        load_model(self.policy_net, self.target_net, self.optimizer, use_prev)

    total_time_steps = 0
    for i_episode in tqdm(range(num_episodes)):
        # Initialize the environment and state
        self.env.reset()
        screen_1 = self.get_screen()
        screen_2 = self.get_screen()
        screen_3 = self.get_screen()
        screen_4 = self.get_screen()
        # state representation is a stack of the previous 4 frames
        state = torch.stack([screen_4, screen_3, screen_2, screen_1],
                            dim=1).view(1, -1, *screen_1.shape[2:])

        for t in count():
            # Select and perform an action
            total_time_steps += 1
            self.memory.decay_beta(total_time_steps)
            action = self.select_action(state, total_time_steps)
            _, reward, done, _ = self.env.step(action.item())

            # Observe new state
            last_state = state
            current_screen = self.get_screen()
            if not done:
                next_state = torch.roll(state, 3, 1)
                next_state[:, 0:3, :, :] = current_screen
            else:
                next_state = None
                reward = -30.0

            # Store the transition in memory
            reward = torch.tensor([reward], device=device)
            self.memory.push(state, action, next_state, reward)

            # Move to the next state
            state = next_state

            # Perform one step of the optimization (on the target network)
            self.optimize_model()
            if done:
                episode_durations.append(t + 1)
                break

        # Update the target network, copying all weights and biases in DQN
        if i_episode % TARGET_UPDATE == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        if i_episode != 0 and i_episode % RESULT_UPDATE == 0:
            print(f'After {i_episode} episodes, total reward: {total_time_steps}')
            print(f'Average reward per episode: {total_time_steps / i_episode}')

    plot_durations(episode_durations)
    if save_path is not None:
        save_model(self.policy_net, self.optimizer, save_path)

    print('Complete')
    print('Total time steps: ', total_time_steps)
    self.env.render()
    self.env.close()
def shift_right(tensor):
    shifted = torch.roll(tensor, 1, 1)
    shifted[:, 0] = 0.0
    return shifted
def forward(self, x):
    return torch.roll(x,
                      shifts=(self.displacement, self.displacement),
                      dims=(1, 2))
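A small sketch (tensor shapes assumed, matching the Swin convention of (B, H, W, C)) showing that a cyclic shift like the module above is exactly undone by rolling back with the opposite displacement, which is how shifted-window blocks restore the layout after attention.

import torch

d = 2
x = torch.randn(1, 8, 8, 3)                                   # (B, H, W, C)
shifted = torch.roll(x, shifts=(-d, -d), dims=(1, 2))         # cyclic shift
restored = torch.roll(shifted, shifts=(d, d), dims=(1, 2))    # reverse cyclic shift
assert torch.equal(x, restored)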
def shift_up(tensor):
    shifted = torch.roll(tensor, -1, 0)
    # zero the wrapped row; uses IMAGE_WIDTH as the row count, which assumes square images
    shifted[IMAGE_WIDTH - 1, :] = 0.0
    return shifted
def gen_period_labels(period_len, cutby, rollby, ratio=0):
    a = torch.arange(0, period_len).repeat(cutby // period_len + 1 + ratio).view(1, -1, 1)
    a = torch.roll(a, rollby, dims=1)
    return a[:, :cutby, :]
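A usage sketch for gen_period_labels above (argument values chosen for illustration): the helper tiles 0..period_len-1, rolls the repeating pattern by rollby, and trims the result to cutby steps.

import torch

labels = gen_period_labels(period_len=3, cutby=7, rollby=1)
print(labels.view(-1))   # tensor([2, 0, 1, 2, 0, 1, 2]): period-3 pattern rolled by 1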
def forward(self, query, hw_shape):
    B, L, C = query.shape
    H, W = hw_shape
    assert L == H * W, 'input feature has wrong size'
    query = query.view(B, H, W, C)

    # pad feature maps to multiples of window size
    pad_r = (self.window_size - W % self.window_size) % self.window_size
    pad_b = (self.window_size - H % self.window_size) % self.window_size
    query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
    H_pad, W_pad = query.shape[1], query.shape[2]

    # cyclic shift
    if self.shift_size > 0:
        shifted_query = torch.roll(query,
                                   shifts=(-self.shift_size, -self.shift_size),
                                   dims=(1, 2))

        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # nW, window_size, window_size, 1
        mask_windows = self.window_partition(img_mask)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0))
    else:
        shifted_query = query
        attn_mask = None

    # nW*B, window_size, window_size, C
    query_windows = self.window_partition(shifted_query)
    # nW*B, window_size*window_size, C
    query_windows = query_windows.view(-1, self.window_size**2, C)

    # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
    attn_windows = self.w_msa(query_windows, mask=attn_mask)

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

    # B H' W' C
    shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x,
                       shifts=(self.shift_size, self.shift_size),
                       dims=(1, 2))
    else:
        x = shifted_x

    if pad_r > 0 or pad_b:
        x = x[:, :H, :W, :].contiguous()

    x = x.view(B, H * W, C)
    x = self.drop(x)
    return x
def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode):
    atol = 1e-3

    # Create the test model and inputs for it
    torch_model = PyTorchTestModel(upscale_factor=scale)
    input_shape = (1, NumChannels, IFMDim, IFMDim)
    test_in = torch.arange(0, np.prod(np.asarray(input_shape)))
    # Limit the input to values valid for the given datatype
    test_in %= dt.max() - dt.min() + 1
    test_in += dt.min()
    # Additionally make sure we always start with 0, for convenience purposes.
    test_in = torch.roll(test_in, dt.min())
    test_in = test_in.view(*input_shape).type(torch.float32)

    # Get golden PyTorch and ONNX inputs
    golden_torch_float = torch_model(test_in)
    export_path = f"{tmpdir}/Upsample_exported.onnx"
    FINNManager.export(torch_model, input_shape=input_shape,
                       export_path=export_path, opset_version=11)
    model = ModelWrapper(export_path)
    input_dict = {model.graph.input[0].name: test_in.numpy()}
    golden_output_dict = oxe.execute_onnx(model, input_dict, True)
    golden_result = golden_output_dict[model.graph.output[0].name]

    # Make sure PyTorch and ONNX match
    pyTorch_onnx_match = np.isclose(golden_result, golden_torch_float).all()
    assert pyTorch_onnx_match, "ONNX and PyTorch upsampling output don't match."

    # Prep model for execution
    model = ModelWrapper(export_path)
    # model = model.transform(TransposeUpsampleIO())
    model = model.transform(MakeInputChannelsLast())
    model = model.transform(InferDataLayouts())
    model = model.transform(absorb.AbsorbTransposeIntoResize())
    model = model.transform(InferShapes())
    model = model.transform(ForceDataTypeForTensors(dType=dt))
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(InferUpsample())
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # Check that all nodes are UpsampleNearestNeighbour_Batch nodes
    for n in model.get_finn_nodes():
        node_check = n.op_type == "UpsampleNearestNeighbour_Batch"
        assert node_check, "All nodes should be UpsampleNearestNeighbour_Batch nodes."

    # Prep sim
    if exec_mode == "cppsim":
        model = model.transform(PrepareCppSim())
        model = model.transform(CompileCppSim())
        model = model.transform(SetExecMode("cppsim"))
    elif exec_mode == "rtlsim":
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(PrepareIP("xc7z020clg400-1", 10))
        model = model.transform(HLSSynthIP())
        model = model.transform(SetExecMode("rtlsim"))
        model = model.transform(PrepareRTLSim())
    else:
        raise Exception("Unknown exec_mode")

    # Run sim
    test_in_transposed = test_in.numpy().transpose(_to_chan_last_args)
    input_dict = {model.graph.input[0].name: test_in_transposed}
    output_dict = oxe.execute_onnx(model, input_dict, True)
    test_result = output_dict[model.graph.output[0].name]
    output_matches = np.isclose(golden_result, test_result, atol=atol).all()

    if exec_mode == "cppsim":
        assert output_matches, "Cppsim output doesn't match ONNX/PyTorch."
    elif exec_mode == "rtlsim":
        assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch."
def forward_part1(self, x, mask_matrix):
    x_shape = x.size()
    x = self.norm1(x)
    if len(x_shape) == 5:
        b, d, h, w, c = x.shape
        window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
        pad_l = pad_t = pad_d0 = 0
        pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
        pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
        pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
        _, dp, hp, wp, _ = x.shape
        dims = [b, dp, hp, wp]
    elif len(x_shape) == 4:
        b, h, w, c = x.shape
        window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
        pad_l = pad_t = 0
        pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
        pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, hp, wp, _ = x.shape
        dims = [b, hp, wp]

    if any(i > 0 for i in shift_size):
        if len(x_shape) == 5:
            shifted_x = torch.roll(x,
                                   shifts=(-shift_size[0], -shift_size[1], -shift_size[2]),
                                   dims=(1, 2, 3))
        elif len(x_shape) == 4:
            shifted_x = torch.roll(x,
                                   shifts=(-shift_size[0], -shift_size[1]),
                                   dims=(1, 2))
        attn_mask = mask_matrix
    else:
        shifted_x = x
        attn_mask = None

    x_windows = window_partition(shifted_x, window_size)
    attn_windows = self.attn(x_windows, mask=attn_mask)
    attn_windows = attn_windows.view(-1, *(window_size + (c,)))
    shifted_x = window_reverse(attn_windows, window_size, dims)

    if any(i > 0 for i in shift_size):
        if len(x_shape) == 5:
            x = torch.roll(shifted_x,
                           shifts=(shift_size[0], shift_size[1], shift_size[2]),
                           dims=(1, 2, 3))
        elif len(x_shape) == 4:
            x = torch.roll(shifted_x,
                           shifts=(shift_size[0], shift_size[1]),
                           dims=(1, 2))
    else:
        x = shifted_x

    if len(x_shape) == 5:
        if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
            x = x[:, :d, :h, :w, :].contiguous()
    elif len(x_shape) == 4:
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

    return x
def rollrow():
    return lambda x, shift: torch.roll(x, shift, 0)
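A usage note for rollrow above (example values are illustrative): the factory returns a lambda that circularly shifts the rows of a 2-D tensor.

import torch

roll_rows = rollrow()
m = torch.arange(6).view(3, 2)        # [[0, 1], [2, 3], [4, 5]]
print(roll_rows(m, 1))                # [[4, 5], [0, 1], [2, 3]]: last row wraps to the top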