Example #1
def agent(agent_id, all_cooked_time, all_cooked_bw, net_params_queue,
          exp_queue, model_type):
    torch.set_num_threads(1)

    net_env = env.Environment(all_cooked_time=all_cooked_time,
                              all_cooked_bw=all_cooked_bw,
                              random_seed=agent_id)

    with open(LOG_FILE + '_agent_' + str(agent_id), 'w') as log_file:

        net = A3C(NO_CENTRAL, model_type, [S_INFO, S_LEN], A_DIM,
                  ACTOR_LR_RATE, CRITIC_LR_RATE)

        # initial synchronization of the network parameters from the coordinator

        time_stamp = 0
        for epoch in range(TOTALEPOCH):
            actor_net_params = net_params_queue.get()
            net.hardUpdateActorNetwork(actor_net_params)
            last_bit_rate = DEFAULT_QUALITY
            bit_rate = DEFAULT_QUALITY
            s_batch = []
            a_batch = []
            r_batch = []
            entropy_record = []
            state = torch.zeros((1, S_INFO, S_LEN))

            # the action is from the last decision
            # this is to make the framework similar to the real system
            delay, sleep_time, buffer_size, rebuf, \
            video_chunk_size, next_video_chunk_sizes, \
            end_of_video, video_chunk_remain = \
                net_env.get_video_chunk(bit_rate)

            time_stamp += delay  # in ms
            time_stamp += sleep_time  # in ms

            while not end_of_video and len(s_batch) < TRAIN_SEQ_LEN:
                last_bit_rate = bit_rate

                state = state.clone().detach()

                state = torch.roll(state, -1, dims=-1)

                state[0, 0, -1] = VIDEO_BIT_RATE[bit_rate] / float(
                    np.max(VIDEO_BIT_RATE))  # last quality
                state[0, 1, -1] = buffer_size / BUFFER_NORM_FACTOR  # 10 sec
                state[0, 2, -1] = float(video_chunk_size) / float(
                    delay) / M_IN_K  # kilo byte / ms
                state[
                    0, 3,
                    -1] = float(delay) / M_IN_K / BUFFER_NORM_FACTOR  # 10 sec
                state[0, 4, :A_DIM] = torch.tensor(
                    next_video_chunk_sizes) / M_IN_K / M_IN_K  # mega byte
                state[0, 5, -1] = min(
                    video_chunk_remain,
                    CHUNK_TIL_VIDEO_END_CAP) / float(CHUNK_TIL_VIDEO_END_CAP)

                bit_rate = net.actionSelect(state)
                # Note: we need to discretize the probability into 1/RAND_RANGE steps,
                # because there is an intrinsic discrepancy in passing single state and batch states

                delay, sleep_time, buffer_size, rebuf, \
                video_chunk_size, next_video_chunk_sizes, \
                end_of_video, video_chunk_remain = \
                    net_env.get_video_chunk(bit_rate)

                reward = VIDEO_BIT_RATE[bit_rate] / M_IN_K \
                         - REBUF_PENALTY * rebuf \
                         - SMOOTH_PENALTY * np.abs(VIDEO_BIT_RATE[bit_rate] -
                                                   VIDEO_BIT_RATE[last_bit_rate]) / M_IN_K

                s_batch.append(state)
                a_batch.append(bit_rate)
                r_batch.append(reward)
                entropy_record.append(3)  # placeholder; the computed policy entropy would normally go here

                # log time_stamp, bit_rate, buffer_size, reward
                log_file.write(
                    str(time_stamp) + '\t' + str(VIDEO_BIT_RATE[bit_rate]) +
                    '\t' + str(buffer_size) + '\t' + str(rebuf) + '\t' +
                    str(video_chunk_size) + '\t' + str(delay) + '\t' +
                    str(reward) + '\n')
                log_file.flush()

            exp_queue.put([
                s_batch,  # ignore the first chunk
                a_batch,  # since we don't have
                r_batch,  # control over it
                end_of_video,
                {
                    'entropy': entropy_record
                }
            ])

            log_file.write('\n')  # so that in the log we know where video ends
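The clone/roll/overwrite pattern above maintains a fixed-length sliding window of features: torch.roll(state, -1, dims=-1) discards the oldest column by wrapping it to the end, where it is immediately overwritten. A minimal standalone sketch of the pattern (the sizes below are illustrative, not the script's real constants):

import torch

S_INFO, S_LEN = 6, 8                          # illustrative sizes
state = torch.zeros((1, S_INFO, S_LEN))
for step in range(3):
    state = torch.roll(state, -1, dims=-1)    # oldest column wraps to the end...
    state[0, 0, -1] = float(step)             # ...and is overwritten with fresh data
print(state[0, 0])                            # tensor([0., 0., 0., 0., 0., 0., 1., 2.])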
Example #2
def shift_down(tensor):
    shifted = torch.roll(tensor, 1, 0)
    shifted[0, :] = 0.0
    return shifted
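torch.roll is circular, so the row pushed off the bottom reappears at the top; zeroing row 0 afterwards turns the circular roll into a plain downward shift. A quick check using shift_down as defined above:

import torch

t = torch.arange(9.0).view(3, 3)
print(shift_down(t))
# tensor([[0., 0., 0.],
#         [0., 1., 2.],
#         [3., 4., 5.]])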
Example #3
    def predict(self,
                test_df: pd.DataFrame,
                verbose: bool = False) -> pd.DataFrame:
        self.prev_group_test_df = test_df.copy()
        df = test_df.groupby("user_id").apply(self.aggregate)

        user_ids = df.index.to_list()

        # (N, seq)
        content_id_tensor = pad_sequence(df["content_id"].to_list(),
                                         batch_first=True)
        row_id_tensor = pad_sequence(df["row_id"].to_list(), batch_first=True)

        # (N, seq, dim)
        feature_tensor = pad_sequence(df["feature"].to_list(),
                                      batch_first=True)
        batch_size, seq_len, dim = feature_tensor.shape

        seq_len_mask = (content_id_tensor != 0).to(dtype=torch.uint8)
        is_question_mask = pad_sequence(df["is_question_mask"].to_list(),
                                        batch_first=True)

        initial_state = self.get_state(user_ids)
        if "y" in df:
            y = pad_sequence(df["y"].to_list(), batch_first=True)
            y = torch.roll(y, 1, dims=1)
            y[:, 0] = -1
            y = torch.unsqueeze(y, dim=2).float()
        else:
            y = None

        with torch.no_grad():
            if verbose:
                print("feature_tensor", feature_tensor.shape)
                print("content_id_tensor", content_id_tensor.shape)

            pred_logit, states = self.model(content_id=content_id_tensor,
                                            bundle_id=None,
                                            feature=feature_tensor,
                                            user_id=None,
                                            mask=seq_len_mask,
                                            initial_state=initial_state,
                                            ans_prev_correctly=y)
            if y is None:
                self.update_state(user_ids=user_ids, states=states)

        flatten_is_question_mask = is_question_mask.view(batch_size * seq_len)
        flatten_seq_mask = seq_len_mask.view(batch_size * seq_len).bool()

        flatten_pred = torch.sigmoid(pred_logit.view(batch_size * seq_len))
        flatten_row_id = row_id_tensor.view(batch_size * seq_len)

        flatten_pred = flatten_pred[
            flatten_seq_mask & flatten_is_question_mask].cpu().data.numpy()
        flatten_row_id = flatten_row_id[
            flatten_seq_mask & flatten_is_question_mask].cpu().data.numpy()

        pred_df = pd.DataFrame({
            "row_id": flatten_row_id,
            "answered_correctly": flatten_pred
        })

        return pred_df
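The roll applied to y above turns per-step labels into a "previous answer" feature: each position sees the label of the step before it, and -1 marks the start of the sequence, where no previous answer exists. A minimal sketch:

import torch

y = torch.tensor([[1., 0., 1., 1.]])  # answered-correctly labels, shape (N, seq)
prev = torch.roll(y, 1, dims=1)       # shift labels one step to the right
prev[:, 0] = -1                       # no previous answer at the first position
print(prev)                           # tensor([[-1., 1., 0., 1.]])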
Example #4
def run_projection(
    ctx: click.Context,
    network_pkl: str,
    initial_learning_rate: float,
    w_avg_samples: int = 10000,
    initial_noise_factor: float = 0.05,
    regularize_noise_weight: float = 1e5,
):
    """
    Project given image to the latent space of pretrained network pickle.
    Adapted from stylegan3-fun/projector.py
    """
    torch.manual_seed(42)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(
            device)

    # Open Spout stream for target images
    spout = Spout(silent=False, width=1044,
                  height=1088)  # TODO set W and H to 720p ?
    spout.createReceiver('input')
    spout.createSender('output')

    # Stabilize the latent space to make things easier (for StyleGAN3's config t and r models)
    gen_utils.anchor_latent_space(G)

    # == Adapted from project() in stylegan3-fun/projector.py == #

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)
    # Compute w stats.
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device),
                          None)  # [N, L, C]
    print('Projecting in W+ latent space...')
    w_avg = torch.mean(w_samples, dim=0, keepdim=True)  # [1, L, C]
    w_std = (torch.sum((w_samples - w_avg)**2) / w_avg_samples)**0.5
    # Setup noise inputs (only for StyleGAN2 models)
    noise_buffs = {
        name: buf
        for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name
    }
    w_noise_scale = w_std * initial_noise_factor  # noise scale is constant
    lr = initial_learning_rate  # learning rate is constant

    # Load the VGG16 feature detector.
    url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(url, device=device)

    w_opt = w_avg.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([w_opt] + list(noise_buffs.values()),
                                 betas=(0.9, 0.999),
                                 lr=initial_learning_rate)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Init noise.
    for buf in noise_buffs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True
    # == End setup from project() == #

    # Project continuously
    while True:
        # check on close window
        spout.check()
        # receive data
        data = spout.receive()
        # data = align_face(data, G.img_resolution) ## TOO SLOW !!!
        #print(data.shape)

        # Features for target image. Reshape to 256x256 if it's larger to use with VGG16
        # target = np.array(target, dtype=np.uint8)
        target = torch.tensor(data.transpose([2, 0, 1]), device=device)
        target = target.unsqueeze(0).to(device).to(torch.float32)
        if target.shape[2] > 256:
            target = F.interpolate(target, size=(256, 256), mode='area')
        target_features = vgg16(target, resize_images=False, return_lpips=True)

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = w_opt + w_noise
        synth_images = G.synthesis(ws, noise_mode='const')
        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images,
                                         size=(256, 256),
                                         mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images,
                               resize_images=False,
                               return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_buffs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise *
                             torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise *
                             torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight
        # Print in the same line (avoid cluttering the commandline)
        message = f'dist {dist:.7e} | loss {loss.item():.7e}'
        print(message, end='\r')

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_buffs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

        # Produce image
        data = gen_utils.w_to_img(G,
                                  dlatents=w_opt.detach()[0],
                                  noise_mode='const')[0]

        spout.send(data)
Example #5
def shift_left(tensor):
    shifted = torch.roll(tensor, -1, 1)
    shifted[:, IMAGE_WIDTH - 1] = 0.0
    return shifted
Example #6
 def __call__(self, f):
     for i in range(self.lattice.Q):
         f[i] = torch.roll(f[i],
                           shifts=tuple(self.lattice.stencil.e[i]),
                           dims=tuple(np.arange(self.lattice.D)))
     return f
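This __call__ is the streaming step of a lattice-Boltzmann solver: torch.roll advects each population f[i] one cell along its stencil velocity e[i], with periodic wrap-around at the domain boundary. A self-contained sketch with a hypothetical 2-D stencil (the real velocities live in self.lattice.stencil.e):

import numpy as np
import torch

e = np.array([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]])  # hypothetical D2Q5-like velocities
f = torch.rand(len(e), 4, 4)                               # populations on a 4x4 periodic grid
for i in range(len(e)):
    f[i] = torch.roll(f[i], shifts=tuple(e[i]), dims=(0, 1))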
Example #7
def scalar_last2first(X):
    return torch.roll(X, 1, -1)
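Rolling by +1 along the last axis rotates the final entry to the front; Example #19 below (scalar_first2last) is the inverse rotation. For instance:

import torch

X = torch.tensor([1, 2, 3, 4])
print(torch.roll(X, 1, -1))   # tensor([4, 1, 2, 3])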
Example #8
    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        height, width = self.input_resolution
        batch_size, dim, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states,
                                               shifts=(-self.shift_size,
                                                       -self.shift_size),
                                               dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states,
                                                 self.window_size)
        hidden_states_windows = hidden_states_windows.view(
            -1, self.window_size * self.window_size, channels)

        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(hidden_states_windows.device)

        self_attention_outputs = self.attention(
            hidden_states_windows,
            self.attn_mask,
            head_mask,
            output_attentions=output_attentions,
        )

        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[
            1:]  # add self attentions if we output attention weights

        attention_windows = attention_output.view(-1, self.window_size,
                                                  self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size,
                                         height, width)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows,
                                           shifts=(self.shift_size,
                                                   self.shift_size),
                                           dims=(1, 2))
        else:
            attention_windows = shifted_windows

        attention_windows = attention_windows.view(batch_size, height * width,
                                                   channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        outputs = (layer_output, ) + outputs

        return outputs
Example #9
def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    device: torch.device
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)       # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    # url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    # with dnnlib.util.open_url(url) as f:
    #     vgg16 = torch.jit.load(f).eval().to(device)

    vgg16 = torch.jit.load('vgg16.pt').eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(torch.from_numpy(w_avg).repeat_interleave(repeats=G.mapping.num_ws, dim=1), dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
    distances = []

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)

        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        distances.append(dist.item())
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out, distances
Example #10
    def add_speed_perturb(self, targets, targ_lens):
        """Adds speed perturbation and random_shift to the input signals"""

        min_len = -1
        recombine = False

        if self.hparams.use_speedperturb:
            # Performing speed change (independently on each source)
            new_targets = []
            recombine = True

            for i in range(targets.shape[-1]):
                new_target = self.hparams.speedperturb(
                    targets[:, :, i], targ_lens
                )
                new_targets.append(new_target)
                if i == 0:
                    min_len = new_target.shape[-1]
                else:
                    if new_target.shape[-1] < min_len:
                        min_len = new_target.shape[-1]

            if self.hparams.use_rand_shift:
                # Performing random_shift (independently on each source)
                recombine = True
                for i in range(targets.shape[-1]):
                    rand_shift = torch.randint(
                        self.hparams.min_shift, self.hparams.max_shift, (1,)
                    )
                    new_targets[i] = new_targets[i].to(self.device)
                    new_targets[i] = torch.roll(
                        new_targets[i], shifts=(rand_shift[0],), dims=1
                    )

            # Re-combination
            if recombine:
                if self.hparams.use_speedperturb:
                    targets = torch.zeros(
                        targets.shape[0],
                        min_len,
                        targets.shape[-1],
                        device=targets.device,
                        dtype=torch.float,
                    )
                for i, new_target in enumerate(new_targets):
                    targets[:, :, i] = new_targets[i][:, 0:min_len]

        # this applies the same speed perturb to each source
        if self.hparams.use_speedperturb_sameforeachsource:

            targets = targets.permute(0, 2, 1)
            targets = targets.reshape(-1, targets.shape[-1])
            wav_lens = torch.tensor([targets.shape[-1]] * targets.shape[0]).to(
                self.device
            )

            targets = self.hparams.speedperturb(targets, wav_lens)
            targets = targets.reshape(
                -1, self.hparams.num_spks, targets.shape[-1]
            )
            targets = targets.permute(0, 2, 1)

        mix = targets.sum(-1)
        return mix, targets
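The random_shift branch rotates each source independently along the time axis, which decorrelates the sources before they are re-mixed. A minimal sketch with hypothetical shift bounds (the originals come from self.hparams.min_shift / max_shift):

import torch

min_shift, max_shift = -8000, 8000       # hypothetical bounds, in samples
wav = torch.randn(2, 16000)              # (batch, time)
rand_shift = torch.randint(min_shift, max_shift, (1,))
wav = torch.roll(wav, shifts=(int(rand_shift[0]),), dims=1)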
Example #11
    def step(self, action, last_prediction, time_to_guide):
        # action: log to linear
        if time_to_guide:
            bandwidth_prediction = last_prediction * pow(2, (2 * action - 1))
            self.gcc_estimator.change_bandwidth_estimation(
                bandwidth_prediction)
        else:
            bandwidth_prediction = last_prediction
        #bandwidth_prediction = action

        # run the action, get related packet list:
        packet_list, done = self.gym_env.step(bandwidth_prediction)

        for pkt in packet_list:
            packet_info = PacketInfo()
            packet_info.payload_type = pkt["payload_type"]
            packet_info.ssrc = pkt["ssrc"]
            packet_info.sequence_number = pkt["sequence_number"]
            packet_info.send_timestamp = pkt["send_time_ms"]
            packet_info.receive_timestamp = pkt["arrival_time_ms"]
            packet_info.padding_length = pkt["padding_length"]
            packet_info.header_length = pkt["header_length"]
            packet_info.payload_size = pkt["payload_size"]
            packet_info.bandwidth_prediction = bandwidth_prediction
            self.packet_record.on_receive(packet_info)
            self.gcc_estimator.report_states(pkt)

        # calculate state:
        self.receiving_rate = self.packet_record.calculate_receiving_rate(
            interval=self.step_time)
        #todo
        self.receiving_rate_list.append(self.receiving_rate)
        # states.append(liner_to_log(receiving_rate))
        # self.receiving_rate.append(receiving_rate)
        # np.delete(self.receiving_rate, 0, axis=0)
        self.delay = self.packet_record.calculate_average_delay(
            interval=self.step_time)
        self.delay_list.append(self.delay)
        # states.append(min(delay/1000, 1))
        # self.delay.append(delay)
        # np.delete(self.delay, 0, axis=0)
        self.loss_ratio = self.packet_record.calculate_loss_ratio(
            interval=self.step_time)
        self.loss_ratio_list.append(self.loss_ratio)
        self.bandwidth_prediction = bandwidth_prediction
        self.bandwidth_prediction_list.append(bandwidth_prediction)
        self.gcc_decision = self.gcc_estimator.get_estimated_bandwidth()

        self.state = self.state.clone().detach()
        self.state = torch.roll(self.state, -1, dims=-1)
        # states.append(loss_ratio)
        # self.loss_ratio.append(loss_ratio)
        # np.delete(self.loss_ratio, 0, axis=0)
        # latest_prediction = self.packet_record.calculate_latest_prediction()
        # states.append(liner_to_log(latest_prediction))
        # self.prediction_history.append(latest_prediction)
        # np.delete(self.prediction_history, 0, axis=0)
        # states = np.vstack((self.receiving_rate, self.delay, self.loss_ratio, self.prediction_history))
        # todo: regularization needs to be fixed
        self.state[0, 0, -1] = self.receiving_rate / 300000.0
        self.state[0, 1, -1] = self.delay / 1000.0
        self.state[0, 2, -1] = self.loss_ratio
        self.state[0, 3, -1] = self.bandwidth_prediction / 300000.0

        # maintain list length
        if len(self.receiving_rate_list) == self.config['state_length']:
            self.receiving_rate_list.pop(0)
            self.delay_list.pop(0)
            self.loss_ratio_list.pop(0)

        # calculate reward:
        reward = self.get_reward()

        return self.state, reward, done, self.gcc_decision
Example #12
def main():
    plt.ion()  # enable interactive mode for continuous plotting
    opt = TrainOptions().parse()
    # device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")
    # define the discriminator and generator networks
    #net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)#batchnorm
    net_d = Discriminator(opt.output_nc)
    init_weights(net_d)
    #net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    net_g = networks.define_G(1, 65, opt.ngf, 'CTnet', opt.norm,
                      not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    init_weights(net_g)
    net_d.to(device)
    net_g.to(device)
    if load_net:
        # save_filename = 'net_d%s.pth' % epoch_start
        # save_path = os.path.join('./check/', save_filename)
        # load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g, save_path)
    # loss functions
    #criterion = nn.BCELoss().to(device)
    criterion = nn.MSELoss().to(device)
    criterion1 = nn.L1Loss().to(device)
    # labels for real and fake data
    true_label = Variable(torch.ones(BATCH_SIZE)).to(device)
    fake_label = Variable(torch.zeros(BATCH_SIZE)).to(device)
    # optimizers
    #optimizer_d = torch.optim.Adam(net_d.parameters(), lr=0.0008,betas=[0.3,0.9])
    optimizer_g = torch.optim.Adam(net_g.parameters(), lr=0.001, betas=[0.9, 0.9])

    #optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    #optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    #one = torch.FloatTensor([1]).cuda()
    #mone = one * -1
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    for epoch in range(MAX_EPOCH):
        if epoch == 900:
            at = 1  # no-op debug hook (presumably a breakpoint anchor)
        # add noise to the real data
        for ii, data in enumerate(dataset):

            g_noises = data['A'].cuda()
            g_noises=g_noises.squeeze(0)
            real_data = data['B'].cuda()
            real_data=real_data.squeeze(0)
            optimizer_g.zero_grad()
            fake_date = net_g(g_noises)
            # aa=torch.roll(torch.roll(real_data,32,2),32,3)
            # a1=aa[0,32,:,:].data.cpu().numpy()
            # bb=torch.roll(torch.roll(aa.flip(2).flip(3),1,2),1,3)
            # b1=bb[0,32,:,:].data.cpu().numpy()
            # c1=b1-a1
            #loss=criterion(torch.roll(torch.roll(fake_date,32,2),32,3),torch.roll(torch.roll(real_data,32,2),32,3))\
                # +criterion(torch.roll(torch.roll(torch.roll(torch.roll(fake_date,32,2),32,3).flip(2).flip(3),1,2),1,3),torch.roll(torch.roll(real_data,32,2),32,3))
            # if ii%2==0:
            #     #loss =criterion(torch.roll(torch.roll(torch.roll(torch.roll(fake_date,32,2),32,3).flip(2).flip(3),1,2),1,3),torch.roll(torch.roll(real_data,32,2),32,3))
            #     loss = criterion(torch.roll(torch.roll(fake_date,32,2),32,3).flip(2).flip(3),torch.roll(torch.roll(real_data,32,2),32,3))
            #
            # else:
            #     loss=criterion(torch.roll(torch.roll(fake_date,32,2),32,3),torch.roll(torch.roll(real_data,32,2),32,3))
            # loss1 = criterion(torch.roll(torch.roll(fake_date,32,2),32,3),torch.roll(torch.roll(real_data,32,2),32,3))
            # p1=fake_date[:,:,0:32,0:65]
            # p2=fake_date[:,:,32,0:32].unsqueeze(2)
            # p2=torch.cat((p2,fake_date[:,:,32,32].unsqueeze(2).unsqueeze(3)),3)
            # p2= torch.cat((p2, fake_date[:, :, 32, 0:32].unsqueeze(2).flip(3)), 3)
            # P=torch.cat((p1,p2),2)
            # P=torch.cat((P,p1.flip(2).flip(3)),2)
            # loss1=criterion(P,real_data)
            loss1 = criterion(fake_date, real_data)
            #loss2 = criterion(torch.abs(torch.roll(torch.roll(fake_date,32,2),32,3)-torch.roll(torch.roll(torch.roll(torch.roll(fake_date,32,2),32,3).flip(2).flip(3),1,2),1,3)),torch.zeros(fake_date.size()).cuda())
            loss2 = criterion1(fake_date, fake_date.flip(2).flip(3))  # symmetry term; not yet giving the intended effect
            loss = loss1
            #loss = loss1
            #loss1.backward(retain_graph=True)
            loss.backward()
            optimizer_g.step()


            # #real_data = np.vstack([POINT*POINT + np.random.normal(0, 0.01, SAMPLE_NUM) for _ in range(BATCH_SIZE)])
            # #real_data = np.vstack([np.sin(POINT) + np.random.normal(0, 0.01, SAMPLE_NUM) for _ in range(BATCH_SIZE)])
            # #real_data = Variable(torch.Tensor(real_data)).to(device)
            # # 用随机噪声作为生成器的输入
            # #g_noises = np.random.randn(BATCH_SIZE, N_GNET)
            # #g_noises = Variable(torch.Tensor(g_noises)).to(device)
            #
            #
            # # train the discriminator
            # # for p in net_d.parameters():  # reset requires_grad
            # #     p.requires_grad = True  # they are set to False below in netG update
            #
            # optimizer_d.zero_grad()
            # # discriminator loss on real images
            # d_real = net_d(real_data)
            # #loss_d_real = criterion(d_real, true_label)
            # loss_d_real = -d_real.mean()
            # #loss_d_real.backward()
            # # discriminator loss on fake images
            # fake_date = net_g(g_noises)
            # d_fake = net_d(fake_date.detach())
            # #loss_d_fake = criterion(d_fake, fake_label)
            # loss_d_fake =d_fake.mean()
            # #loss_d_fake.backward()
            #
            # # train with gradient penalty
            # gradient_penalty = calc_gradient_penalty(net_d, real_data, fake_date)
            # #gradient_penalty.backward()
            #
            # D_cost = loss_d_fake + loss_d_real + gradient_penalty
            # D_cost.backward()
            # Wasserstein_D = loss_d_real - loss_d_fake
            # optimizer_d.step()
            # if ii%CRITIC_ITERS==0:
            #     # train the generator
            #     # for p in net_d.parameters():
            #     #     p.requires_grad = False  # to avoid computation
            #     optimizer_g.zero_grad()
            #     fake_date = net_g(g_noises)
            #     d_fake = net_d(fake_date)
            #     # generator loss on fake images
            #     #loss_g = criterion(d_fake, true_lable)
            #     loss_g =-d_fake.mean()
            #     loss_g.backward()
            #     optimizer_g.step()
            #     G_cost = -loss_g
            #     for name, parms in net_g.named_parameters():
            #         if name=='model.2.weight':
            #             print('layer:', name, parms.size(), '-->name:', name, '-->grad_requires:', parms.requires_grad, \
            #               ' -->grad_value:', parms.grad[0])
            #     a = 1
            #     # for name, parms in self.netG_A.named_parameters():

            # periodically plot the generated images and related data
            if ii % 10 == 0:
                #print(fake_date[0]) plt.ion()
    #             plt.ion()
    #             plt.cla()
    #             # plt.plot(POINT, fake_date[0].to('cpu').detach().numpy(), c='#4AD631', lw=2,
    #             #          label="generated line")  # data produced by the generator
    #             # plt.plot(POINT, real_data[0].to('cpu').detach().numpy(), c='#74BCFF', lw=3, label="real sin")  # real data
    #             #prob = (loss_d_real.mean() + 1 - loss_d_fake.mean()) / 2.
    #
    #             img_test=[]
                fk_im = toimage(torch.irfft(torch.roll(torch.roll(fake_date, -32, 2), -32, 3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                #fk_im = toimage(torch.irfft(fake_date.permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                #fk_im=toimage(fake_date[0,32,:,:].unsqueeze(0).unsqueeze(0))
                #img_test.append(fk_im)
                save_filenamet = 'fake%s.bmp' % epoch
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)

                rel_im = toimage(torch.irfft(torch.roll(torch.roll(real_data, -32, 2), -32, 3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                save_image(rel_im, os.path.join('./check/img/', 'Real%s.bmp' % epoch))
                test(net_g)
                message = '(epoch: %d, iters: %d, loss1: %.3f, loss2: %.3f) ' % (epoch, ii, loss1, loss2)
                print(message)

    #             rel_im=toimage(torch.irfft(real_data.squeeze(0).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
    #             img_test.append(rel_im)
    #
    #
    #             for it in range(1, 3):
    #                 plt.subplot(1, 2, it)
    #                 plt.imshow(img_test[it - 1])
    #             plt.text(-1, 81, 'D accuracy=%.2f ' % (D_cost.mean()),
    #                      fontdict={'size': 15})
    #             plt.text(-1, 85, 'G accuracy=%.2f ' % (G_cost),
    #                      fontdict={'size': 15})
    #             plt.text(-1, 89, 'W accuracy=%.2f ' % (Wasserstein_D),
    #                      fontdict={'size': 15})
    #             plt.text(-1, 95,  'epoch=%.2f ' % (epoch),
    #                      fontdict={'size': 15})
    #             plt.show()
    #            # plt.ylim(-2, 2)
    #             plt.draw(), plt.pause(0.1),plt.clf()
    #     save_filename = 'net_d%s.pth' % epoch
    #     save_path = os.path.join('./check/', save_filename)
    #     torch.save(net_d.cpu().state_dict(), save_path)
    #     net_d.cuda(0)
        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g.cpu().state_dict(), save_path)
        net_g.cuda(0)
Example #13
File: train.py  Project: gunnxx/rove
def train(model, optimizer, train_iterator, params):
    """Train the model on one epoch

    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        train_iterator: (generator) a generator that generates batches of data and labels
        params: (Params) hyperparameters
    """

    # reload weights from checkpoint if specified
    if args.restore_file:
        restore_path = os.path.join(args.experiment_dir, args.restore_file)

        logging.info("Restoring parameters from {}".format(restore_path))
        if not os.path.exists(restore_path):
            raise FileNotFoundError("Checkpoint file doesn't exist: {}".format(restore_path))

        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # set model to training mode
    model.train()

    # init some necessary variables
    best_loss = float('inf')
    total_loss = 0.
    steps_count = 0
    num_steps = params.train_size // params.batch_size

    # train for params.num_epochs
    for epoch in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        for batch in tqdm.tqdm(train_iterator, total=num_steps):
            loss = torch.tensor(0).to(params.device).float()
            output = model(batch.input.to(params.device).float())

            # compute loss to minimize distance between center word and
            # context words based on params.context_window size
            for i in range(1, params.context_window + 1):
                loss += net.cos_embedding_loss(output[i:, :, :],
                                               output[:-i, :, :])

            # negative-samples loss
            # negative-samples are taken by randomly rolling batch
            for i in range(params.negative_samples):
                loss += net.cos_embedding_loss(
                    output, torch.roll(output, randrange(params.batch_size),
                                       1), True)

            steps_count += 1
            total_loss += loss.item()

            # save best and latest model for every params.steps_to_save
            if steps_count % params.steps_to_save == 0:
                mean_loss = total_loss / params.steps_to_save

                if mean_loss <= best_loss:
                    best_loss = mean_loss

                    path = os.path.join(args.experiment_dir,
                                        'best_loss.pth.tar')
                    torch.save(
                        {
                            'epoch': epoch + 1,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'loss': mean_loss
                        }, path)

                path = os.path.join(args.experiment_dir, 'latest.pth.tar')
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'loss': mean_loss
                    }, path)

                logging.info("- Average training loss: {}".format(mean_loss))

                # reset variables
                steps_count = 0
                total_loss = 0

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Example #14
def _rotate_image(image, idx):
    width = image.shape[2]
    if idx < 0 or idx >= width:
        return
    rotated = torch.roll(image, idx, 2)
    return rotated
Example #15
def main():
    plt.ion()  # enable interactive mode for continuous plotting
    opt = TrainOptions().parse()
    # device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")

    # define the discriminator and generator networks
    #net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)#batchnorm
    #net_d = Discriminator(opt.output_nc)
    net_d_ct =networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_d_dr =networks.define_D(opt.input_nc, opt.ndf, 'ProjNet',
                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_g_dr=networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

    #net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    net_g_ct = networks.define_G(1, 65, opt.ngf, 'CTnet', opt.norm,
                      not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    # init_weights(net_d_dr)
    # init_weights(net_d_ct)
    # init_weights(net_g_dr)
    # init_weights(net_g_ct)
    net_d_ct.to(device)
    net_d_dr.to(device)
    net_g_dr.to(device)
    net_g_ct.to(device)

    one = torch.FloatTensor([1])
    mone = one * -1
    one = one.to(device)
    mone= mone.to(device)
    #summary(net_g_dr, (2,65, 65,65))
    if load_net:
        # save_filename = 'net_d%s.pth' % epoch_start
        # save_path = os.path.join('./check/', save_filename)
        # load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g_ct, save_path)
    # loss functions
    #criterion = nn.BCELoss().to(device)
    criterion = nn.MSELoss().to(device)
    criterion1 = nn.L1Loss().to(device)

    # optimizers
    optimizer_d = torch.optim.Adam(itertools.chain(net_d_ct.parameters(),net_d_dr.parameters()), lr=0.0001,betas=[0.5,0.9])
    optimizer_g = torch.optim.Adam(itertools.chain(net_g_ct.parameters(),net_g_dr.parameters()), lr=0.0001,betas=[0.5,0.9])

    #optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    #optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    #one = torch.FloatTensor([1]).cuda()
    #mone = one * -1
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    def gensample():
        for image in enumerate(dataset):
            yield image
    gen=gensample()
    ii=0

    for epoch in range(MAX_EPOCH):
        # add noise to the real data
        for it,data in enumerate(dataset):


            # load data
            dr_real = autograd.Variable(data['A'].cuda())
            dr_real = dr_real.squeeze(0)
            ct_real = autograd.Variable(data['B'].cuda())
            ct_real = ct_real.squeeze(0)
            # training
            # inner loop
            freeze_params(net_g_ct)
            freeze_params(net_g_dr)
            unfreeze_params(net_d_ct)
            unfreeze_params(net_d_dr)

            ct_fake = autograd.Variable(net_g_ct(dr_real).data)
            dr_fake = autograd.Variable(net_g_dr(ct_real).data)

            optimizer_d.zero_grad()
            loss_dsc_realct = net_d_ct(ct_real).mean()
            #loss_dsc_realct.backward()
            loss_dsc_fakect = net_d_ct(ct_fake.detach()).mean()
            #loss_dsc_fakect.backward()
            gradient_penalty_ct = calc_gradient_penalty(net_d_ct, ct_real, ct_fake)
            #gradient_penalty_ct.backward()
            loss_d_ct=loss_dsc_fakect - loss_dsc_realct+gradient_penalty_ct
            loss_d_ct.backward()
            Wd_ct=loss_dsc_realct-loss_dsc_fakect

            loss_dsc_realdr = net_d_dr(dr_real).mean()
            #loss_dsc_realdr.backward()
            loss_dsc_fakedr = net_d_dr(dr_fake.detach()).mean()
            #loss_dsc_fakedr.backward()
            gradient_penalty_dr = calc_gradient_penalty(net_d_dr, dr_real, dr_fake)
            #gradient_penalty_dr.backward()
            loss_d_dr = loss_dsc_fakedr - loss_dsc_realdr + gradient_penalty_dr
            loss_d_dr.backward()
            Wd_dr = loss_dsc_realdr - loss_dsc_fakedr
            optimizer_d.step()
            if it%CRITIC_ITERS==0:
            #if True:
                unfreeze_params(net_g_ct)
                freeze_params(net_d_ct)
                unfreeze_params(net_g_dr)
                freeze_params(net_d_dr)


                ct_fake_g=net_g_ct(dr_real)
                dr_fake_g=net_g_dr(ct_real)

                # outer loop: ct_dr
                # optimizer_g.zero_grad()
                loss_out_dr=criterion1(net_g_ct(dr_fake_g),ct_real)
                # loss_out_dr.backward()
                # optimizer_g.step()
                #net_g_ct.load_state_dict(dict_g_ct)

                # outer loop: dr_ct
                # optimizer_g.zero_grad()
                loss_out_ct=criterion1(net_g_dr(ct_fake_g),dr_real)
                # loss_out_ct.backward()
                # optimizer_g.step()

                # inner-loop GAN
                loss_g_ct = - net_d_ct(ct_fake_g).mean()
                #loss_g_ct.backward()
                loss_g_dr = - net_d_dr(dr_fake_g).mean()
                #loss_g_dr.backward()
                loss_gan = loss_out_dr + loss_out_ct
                #loss_gan=loss_out_dr+loss_out_ct+loss_g_ct+loss_g_dr
                #loss_gan = criterion(net_g_ct(dr_real), ct_real) + criterion(net_g_dr(ct_real), dr_real)
                optimizer_g.zero_grad()
                loss_gan.backward()
                optimizer_g.step()

            if it%1==0:
                fk_im=toimage(torch.irfft(torch.roll(torch.roll(ct_fake,-32,2),-32,3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))

                #fk_im=toimage(ct_fake[0,32,:,:].unsqueeze(0))
                 #img_test.append(fk_im)
                save_filenamet = 'fakect%s.bmp' % int(epoch/dataset_size)
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)

                rel_im=toimage(torch.irfft(torch.roll(torch.roll(ct_real,-32,2),-32,3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                #rel_im = toimage(ct_real[0,32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                save_image(rel_im, os.path.join('./check/img/', 'Realct%s.bmp' % int(epoch)))

                fake_im = toimage(torch.irfft(torch.roll(torch.roll(dr_fake, -128, 2), -128, 3).permute(1, 2, 3, 0), 2, onesided=False))
                #fake_im =toimage(dr_fake.squeeze(0))
                save_image(fake_im, os.path.join('./check/img/', 'fakedr%s.bmp' % int(epoch)))
                ceshi(net_g_ct)
                message = '(epoch: %d, iters: %d, D_ct: %.3f;[real:%.3f;fake:%.3f], G_ct: %.3f, D_dr: %.3f, G_dr: %.3f) ' % (int(epoch), ii,loss_d_ct,loss_dsc_realct,loss_dsc_fakect,loss_g_ct,loss_d_dr,loss_g_dr)
                print(message)

        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g_ct.cpu().state_dict(), save_path)
        net_g_ct.cuda(0)
Example #16
def shape_change(dstrfs,
                 niter=100,
                 span=None,
                 batch_size=8,
                 verbose=0):  #return_shifts=False):
    """
    Align all dSTRFs to the global mean by shifting them along the lag axis.
    Compute shape change nonlinearity measure on the aligned dSTRFs.
    
    Arguments:
        dstrfs: tensor of dSTRFs with shape [time * channel * lag * frequency]
        niter: number of iterations
        span: the maximum shift allowed per iteration [-span, span]
        batch_size: temporal batch size used for computing the shifts
    
    Returns:
        shape_change: shape change parameter, tensor of shape [channel]
        dstrfs: shift-corrected and centered dSTRFs (only returned by the
            commented-out return_shifts variant below)
        shifts: amount of shift applied at each time point (likewise only
            returned when return_shifts is enabled)
    """
    tdim, cdim, ldim, fdim = dstrfs.shape
    if cdim > batch_size:
        return torch.cat([
            shape_change(dstrfs[:, k * batch_size:(k + 1) * batch_size],
                         niter=niter,
                         span=span,
                         batch_size=batch_size)
            for k in range(math.ceil(cdim / batch_size))
        ])

    span = max(ldim - 10, min(ldim, 5)) if span is None else span
    dstrfs = torch.nn.functional.pad(dstrfs, (0, 0, span, span))

    max_shift = ldim // 2 + span
    shifts = torch.zeros((tdim, cdim), dtype=int, device=dstrfs.device)

    t1 = time.time()
    for i in range(niter):
        dmean = dstrfs.mean(dim=0)
        shift = torch.zeros((tdim, cdim), dtype=int, device=dstrfs.device)
        for batch in range(math.ceil(tdim / batch_size)):
            batch_ind = slice(batch * batch_size, (batch + 1) * batch_size)
            shift[batch_ind, :] = utils.find_shift(dstrfs[batch_ind], dmean,
                                                   max_shift)

        for c in range(cdim):
            for k in range(-max_shift, max_shift + 1):
                batch_ind = shift[:, c] == k
                dstrfs[batch_ind, c] = torch.roll(dstrfs[batch_ind, c], k, 1)

        power = dstrfs.abs().mean(dim=[0, 3])
        center = utils.smooth(power, 9).argmax(dim=1)

        for c in range(cdim):
            dstrfs[:, c] = torch.roll(dstrfs[:, c], int(max_shift - center[c]),
                                      1)
            shifts[:, c] += shift[:, c] + max_shift - center[c]

        if (shift == 0).all():
            break
        elif verbose >= 1:
            print('Iteration #%d: %.4f average shift.' %
                  (i + 1, shift.float().mean().cpu()),
                  flush=True)
    t2 = time.time()

    if verbose >= 1:
        if not (shift == 0).all():
            print('Failed to converge to solution ({:s} elapsed).'.format(
                utils.timestr(t2 - t1)))
        else:
            print('Converged in {:d} iterations ({:s} elapsed).'.format(
                i, utils.timestr(t2 - t1)))

    power = dstrfs.abs().mean(dim=[0, 3])
    best_shift = torch.arange(-span, span + 1)[torch.argmax(torch.stack([
        torch.roll(power, k, 1)[:, span:-span].sum(dim=1)
        for k in range(-span, span + 1)
    ]),
                                                            dim=0)]
    for c in range(cdim):
        dstrfs[:, c] = torch.roll(dstrfs[:, c], int(best_shift[c]), 1)
        shifts[:, c] += best_shift[c]
    dstrfs = dstrfs[:, :, span:-span]

    #if return_shifts:
    #    return complexity(dstrfs), dstrfs.cpu(), shifts.cpu()
    #else:
    #    return complexity(dstrfs)
    return complexity(dstrfs)
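The heart of the routine is alignment by lag: candidate shifts are applied with torch.roll and the one that best matches the running mean is kept. utils.find_shift presumably performs that shift search; the version below is a guess at the idea, not the library's implementation:

import torch

def best_roll(x, ref, max_shift):
    # score every circular shift of x against ref and return the best one
    scores = torch.stack([(torch.roll(x, k, 0) * ref).sum()
                          for k in range(-max_shift, max_shift + 1)])
    return int(scores.argmax()) - max_shift

x = torch.zeros(16); x[4] = 1.0
ref = torch.zeros(16); ref[7] = 1.0
k = best_roll(x, ref, 5)         # 3: rolling x by 3 aligns the peaks
aligned = torch.roll(x, k, 0)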
Example #17
 def shift_right(t: torch.Tensor) -> torch.Tensor:
     st = torch.roll(t, 1, 0)
     st[0] = text_encoder.BOS_ID
     return st
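shift_right builds decoder inputs for teacher forcing: every token moves one slot to the right and a beginning-of-sequence id fills slot 0. Sketch with a placeholder id (the original uses text_encoder.BOS_ID):

import torch

BOS_ID = 1                       # placeholder value; see text_encoder.BOS_ID
t = torch.tensor([5, 6, 7, 8])
st = torch.roll(t, 1, 0)
st[0] = BOS_ID
print(st)                        # tensor([1, 5, 6, 7])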
Example #18
def actor_critic_loss(policy, model, dist_class, train_batch):
    assert policy.is_recurrent(), "policy must be recurrent"

    seq_lens = train_batch['seq_lens']
    batch_size = seq_lens.shape[0]
    max_seq_len = torch.max(seq_lens)
    mask_orig = sequence_mask(seq_lens, max_seq_len)
    mask = torch.reshape(mask_orig, [-1])

    horizon = policy.config['fun_horizon']

    manager_horizon_mask = mask_orig.clone()
    manager_horizon_mask[:, -horizon:] = False
    manager_horizon_mask = manager_horizon_mask.reshape(-1)

    _ = model.icm_forward(train_batch[SampleBatch.OBS],
                          train_batch[SampleBatch.NEXT_OBS])
    icm_fwd_loss = model.icm_fwd_forward(train_batch[SampleBatch.ACTIONS])
    icm_inv_loss = model.icm_inv_forward(train_batch[SampleBatch.ACTIONS])
    icm_loss = 0.995 * icm_fwd_loss + 0.005 * icm_inv_loss
    icm_loss = torch.sum(icm_loss * mask)
    icm_loss /= batch_size * max_seq_len
    policy.icm_loss = icm_loss

    # Hacky way of passing data from sample batch to train batch
    model.random_select = train_batch['random_select'].reshape(
        (batch_size, -1))
    model.random_goal = train_batch['random_goal'].reshape(
        (batch_size, max_seq_len, -1))

    logits, _ = model.from_batch(train_batch)
    manager_values, worker_values = model.value_function()
    manager_latent_state, manager_goal = model.manager_features()

    manager_latent_state_future = torch.roll(manager_latent_state, -horizon, 1)
    manager_latent_state_diff = (manager_latent_state_future -
                                 manager_latent_state).detach()

    policy.manager_loss = 10.0 * -torch.sum(
        train_batch['manager_advantages'] * F.cosine_similarity(
            manager_latent_state_diff, manager_goal, dim=-1).reshape(-1) *
        manager_horizon_mask) / (batch_size * max_seq_len)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    policy.entropy = 3e-4 * -torch.sum(dist.entropy() * mask) / (batch_size *
                                                                 max_seq_len)
    policy.pi_err = 0.1 * -torch.sum(
        train_batch['worker_advantages'] * log_probs.reshape(-1) * mask) / (
            batch_size * max_seq_len)

    policy.manager_value_err = torch.sum(
        torch.pow(
            (manager_values.reshape(-1) - train_batch['manager_value_targets'])
            * mask, 2.0)) / (batch_size * max_seq_len)
    policy.worker_value_err = 0.01 * torch.sum(
        torch.pow(
            (worker_values.reshape(-1) - train_batch['worker_value_targets']) *
            mask, 2.0)) / (batch_size * max_seq_len)

    overall_err = sum([
        policy.pi_err,
        policy.manager_value_err,
        policy.worker_value_err,
        policy.entropy,
        policy.manager_loss,
        policy.icm_loss,
    ])
    return overall_err
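torch.roll(manager_latent_state, -horizon, 1) pairs each timestep with the latent state horizon steps ahead; because the roll wraps, the last horizon positions point back at the start of the sequence, which is why manager_horizon_mask zeroes them out. Sketch:

import torch

horizon = 2
s = torch.arange(6.0).view(1, 6, 1)      # (batch, time, features)
future = torch.roll(s, -horizon, 1)      # s[t + horizon], wrapped at the tail
valid = torch.ones(1, 6, dtype=torch.bool)
valid[:, -horizon:] = False              # wrapped tail entries are not real futures
diff = (future - s) * valid.unsqueeze(-1)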
Example #19
def scalar_first2last(X):
    return torch.roll(X, -1, -1)
Example #20
 def forward(self, x1, x2):
     pos = self.net(torch.cat([x1, x2], 1))  # Positive Samples
     neg = self.net(torch.cat([torch.roll(x1, 1, 0), x2], 1))
     return -softplus(-pos).mean() - softplus(
         neg).mean(), pos.mean() - neg.exp().mean() + 1
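Rolling x1 one position along the batch dimension pairs each x1 with a different sample's x2, a cheap way to draw negative (mismatched) pairs for this mutual-information estimator. Sketch:

import torch

x1 = torch.randn(4, 3)
x2 = torch.randn(4, 3)
pos = torch.cat([x1, x2], 1)                    # jointly drawn pairs
neg = torch.cat([torch.roll(x1, 1, 0), x2], 1)  # mismatched pairs via batch roll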
Example #21
def diff(x):
    shift_x = torch.roll(x, 1, 2)
    return ((shift_x - x) + 1) / 2
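diff compares each element with its circular neighbour along dim 2 and rescales the result into [0, 1] for inputs in [0, 1]. For example:

import torch

x = torch.tensor([[[0., 1., 0., 1.]]])   # shape (1, 1, 4)
print(diff(x))                           # tensor([[[1., 0., 1., 0.]]])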
Example #22
    def train(self, num_episodes=300, use_prev=None, save_path=None):
        """
        :param num_episodes: number of episodes to train
        :use_prev: use previous weights to start training from
        :param save_path: where to save the checkpoint at the end of training
        """
        episode_durations = []
        if use_prev is not None:
            load_model(self.policy_net, self.target_net, self.optimizer,
                       use_prev)
        total_time_steps = 0
        for i_episode in tqdm(range(num_episodes)):
            # Initialize the environment and state
            self.env.reset()
            screen_1 = self.get_screen()
            screen_2 = self.get_screen()
            screen_3 = self.get_screen()
            screen_4 = self.get_screen()
            # state representation is a stack of previous 4 frames
            state = torch.stack([screen_4, screen_3, screen_2, screen_1],
                                dim=1).view(1, -1, *screen_1.shape[2:])
            for t in count():
                # Select and perform an action
                total_time_steps += 1
                self.memory.decay_beta(total_time_steps)
                action = self.select_action(state, total_time_steps)
                _, reward, done, _ = self.env.step(action.item())
                # Observe new state
                last_state = state
                current_screen = self.get_screen()
                if not done:
                    next_state = torch.roll(state, 3, 1)
                    next_state[:, 0:3, :, :] = current_screen
                else:
                    next_state = None
                    reward = -30.0

                # Store the transition in memory
                reward = torch.tensor([reward], device=device)
                self.memory.push(state, action, next_state, reward)

                # Move to the next state
                state = next_state

                # Perform one step of the optimization (on the target network)
                self.optimize_model()
                if done:
                    episode_durations.append(t + 1)
                    break
            # Update the target network, copying all weights and biases in DQN
            if i_episode % TARGET_UPDATE == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())

            if i_episode != 0 and i_episode % RESULT_UPDATE == 0:
                print(
                    f'After {i_episode} episodes, total time steps: {total_time_steps}'
                )
                print(
                    f'Average time steps per episode: {total_time_steps/i_episode}'
                )

        plot_durations(episode_durations)
        if save_path is not None:
            save_model(self.policy_net, self.optimizer, save_path)
        print('Complete')
        print('Total time steps: ', total_time_steps)
        self.env.render()
        self.env.close()
Example #23
def shift_right(tensor):
    shifted = torch.roll(tensor, 1, 1)
    shifted[:, 0] = 0.0
    return shifted
Example #24
 def forward(self, x):
     return torch.roll(x,
                       shifts=(self.displacement, self.displacement),
                       dims=(1, 2))
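This cyclic shift is exactly invertible: rolling back by the negative displacement restores the input, which is what makes it usable in shifted-window attention. A quick check, with plain rolls standing in for the module:

import torch

x = torch.arange(16.0).view(1, 4, 4, 1)
shifted = torch.roll(x, shifts=(1, 1), dims=(1, 2))
restored = torch.roll(shifted, shifts=(-1, -1), dims=(1, 2))
assert torch.equal(x, restored)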
Example #25
def shift_up(tensor):
    # Non-circular up shift along dim 0: roll, then zero the wrapped last row
    shifted = torch.roll(tensor, -1, 0)
    shifted[-1, :] = 0.0
    return shifted
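Indexing the last row with -1 keeps shift_up correct for any tensor height, not just IMAGE_WIDTH-sized square images. For example:

import torch

t = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
print(shift_up(t))   # tensor([[3., 4.], [0., 0.]])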
Example #26
def gen_period_labels(period_len, cutby, rollby, ratio=0):
    # Tile 0..period_len-1 until at least cutby entries are available
    a = torch.arange(0, period_len).repeat(cutby // period_len + 1 + ratio).view(1, -1, 1)
    # Rotate the periodic sequence, then trim to exactly cutby labels
    a = torch.roll(a, rollby, dims=1)
    return a[:, :cutby, :]
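For instance, with a period of 3, five requested labels, and a rotation of 1:

import torch

labels = gen_period_labels(period_len=3, cutby=5, rollby=1)
print(labels.flatten().tolist())   # [2, 0, 1, 2, 0]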
Example #27
    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'
        query = query.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if self.shift_size > 0:
            shifted_query = torch.roll(query,
                                       shifts=(-self.shift_size,
                                               -self.shift_size),
                                       dims=(1, 2))

            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H_pad, W_pad, 1),
                                   device=query.device)  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = self.window_partition(img_mask)
            mask_windows = mask_windows.view(
                -1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                              float(-100.0)).masked_fill(
                                                  attn_mask == 0, float(0.0))
        else:
            shifted_query = query
            attn_mask = None

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(shifted_query)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, self.window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x,
                           shifts=(self.shift_size, self.shift_size),
                           dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)
        return x
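The cyclic shift lets shifted windows reuse the same window partition as the unshifted case, and the reverse roll after attention undoes it exactly; the mask only has to flag window positions whose cnt labels came from different image regions. Invertibility in isolation:

import torch

q = torch.arange(36.0).view(1, 6, 6, 1)               # B H W C
s = 2                                                  # example shift_size
shifted = torch.roll(q, shifts=(-s, -s), dims=(1, 2))
assert torch.equal(q, torch.roll(shifted, shifts=(s, s), dims=(1, 2)))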
Example #28
def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode):
    atol = 1e-3
    # Create the test model and inputs for it
    torch_model = PyTorchTestModel(upscale_factor=scale)
    input_shape = (1, NumChannels, IFMDim, IFMDim)
    test_in = torch.arange(0, np.prod(np.asarray(input_shape)))
    # Limit the input to values valid for the given datatype
    test_in %= dt.max() - dt.min() + 1
    test_in += dt.min()
    # Additionally make sure we always start with 0, for convenience purposes.
    test_in = torch.roll(test_in, dt.min())
    test_in = test_in.view(*input_shape).type(torch.float32)

    # Get golden PyTorch and ONNX inputs
    golden_torch_float = torch_model(test_in)
    export_path = f"{tmpdir}/Upsample_exported.onnx"
    FINNManager.export(torch_model,
                       input_shape=input_shape,
                       export_path=export_path,
                       opset_version=11)
    model = ModelWrapper(export_path)
    input_dict = {model.graph.input[0].name: test_in.numpy()}
    golden_output_dict = oxe.execute_onnx(model, input_dict, True)
    golden_result = golden_output_dict[model.graph.output[0].name]

    # Make sure PyTorch and ONNX match
    pyTorch_onnx_match = np.isclose(golden_result, golden_torch_float).all()
    assert pyTorch_onnx_match, "ONNX and PyTorch upsampling output don't match."

    # Prep model for execution
    model = ModelWrapper(export_path)
    # model = model.transform(TransposeUpsampleIO())
    model = model.transform(MakeInputChannelsLast())
    model = model.transform(InferDataLayouts())
    model = model.transform(absorb.AbsorbTransposeIntoResize())
    model = model.transform(InferShapes())
    model = model.transform(ForceDataTypeForTensors(dType=dt))
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(InferUpsample())
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # Check that all nodes are UpsampleNearestNeighbour_Batch nodes
    for n in model.get_finn_nodes():
        node_check = n.op_type == "UpsampleNearestNeighbour_Batch"
        assert node_check, "All nodes should be UpsampleNearestNeighbour_Batch nodes."

    # Prep sim
    if exec_mode == "cppsim":
        model = model.transform(PrepareCppSim())
        model = model.transform(CompileCppSim())
        model = model.transform(SetExecMode("cppsim"))
    elif exec_mode == "rtlsim":
        model = model.transform(GiveUniqueNodeNames())
        model = model.transform(PrepareIP("xc7z020clg400-1", 10))
        model = model.transform(HLSSynthIP())
        model = model.transform(SetExecMode("rtlsim"))
        model = model.transform(PrepareRTLSim())
    else:
        raise Exception("Unknown exec_mode")

    # Run sim
    test_in_transposed = test_in.numpy().transpose(_to_chan_last_args)
    input_dict = {model.graph.input[0].name: test_in_transposed}
    output_dict = oxe.execute_onnx(model, input_dict, True)
    test_result = output_dict[model.graph.output[0].name]
    output_matches = np.isclose(golden_result, test_result, atol=atol).all()

    if exec_mode == "cppsim":
        assert output_matches, "Cppsim output doesn't match ONNX/PyTorch."
    elif exec_mode == "rtlsim":
        assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch."
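The torch.roll(test_in, dt.min()) step rotates the sequence so it starts at 0: after the modulo and offset the data begins at dt.min(), and rolling by that (negative, for signed types) amount moves the zero entry to the front. Sketch with an assumed 2-bit signed range [-2, 1]:

import torch

dt_min, dt_max = -2, 1                        # assumed signed 2-bit datatype
t = torch.arange(0, 8)
t = t % (dt_max - dt_min + 1) + dt_min        # [-2, -1, 0, 1, -2, -1, 0, 1]
t = torch.roll(t, dt_min)                     # roll left by 2
print(t.tolist())                             # [0, 1, -2, -1, 0, 1, -2, -1]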
Example #29
    def forward_part1(self, x, mask_matrix):
        x_shape = x.size()
        x = self.norm1(x)
        if len(x_shape) == 5:
            b, d, h, w, c = x.shape
            window_size, shift_size = get_window_size(
                (d, h, w), self.window_size, self.shift_size)
            pad_l = pad_t = pad_d0 = 0
            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
            _, dp, hp, wp, _ = x.shape
            dims = [b, dp, hp, wp]

        elif len(x_shape) == 4:
            b, h, w, c = x.shape
            window_size, shift_size = get_window_size((h, w), self.window_size,
                                                      self.shift_size)
            pad_l = pad_t = 0
            pad_b = (window_size[0] - h % window_size[0]) % window_size[0]
            pad_r = (window_size[1] - w % window_size[1]) % window_size[1]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
            _, hp, wp, _ = x.shape
            dims = [b, hp, wp]

        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                shifted_x = torch.roll(x,
                                       shifts=(-shift_size[0], -shift_size[1],
                                               -shift_size[2]),
                                       dims=(1, 2, 3))
            elif len(x_shape) == 4:
                shifted_x = torch.roll(x,
                                       shifts=(-shift_size[0], -shift_size[1]),
                                       dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        x_windows = window_partition(shifted_x, window_size)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        attn_windows = attn_windows.view(-1, *(window_size + (c, )))
        shifted_x = window_reverse(attn_windows, window_size, dims)
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                x = torch.roll(shifted_x,
                               shifts=(shift_size[0], shift_size[1],
                                       shift_size[2]),
                               dims=(1, 2, 3))
            elif len(x_shape) == 4:
                x = torch.roll(shifted_x,
                               shifts=(shift_size[0], shift_size[1]),
                               dims=(1, 2))
        else:
            x = shifted_x

        if len(x_shape) == 5:
            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
                x = x[:, :d, :h, :w, :].contiguous()
        elif len(x_shape) == 4:
            if pad_r > 0 or pad_b > 0:
                x = x[:, :h, :w, :].contiguous()

        return x
Example #30
def rollrow():
    # Returns a closure that circularly shifts rows (dim 0) by `shift`
    return lambda x, shift: torch.roll(x, shift, 0)
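rollrow fixes the roll to dim 0; a matching column version, here called rollcol (hypothetical, not part of the snippet), would roll along dim 1:

import torch

def rollcol():
    # hypothetical companion: circular shift along columns (dim 1)
    return lambda x, shift: torch.roll(x, shift, 1)

m = torch.tensor([[1, 2],
                  [3, 4]])
print(rollrow()(m, 1))   # tensor([[3, 4], [1, 2]])
print(rollcol()(m, 1))   # tensor([[2, 1], [4, 3]])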