Example #1
    def forward(self, x):
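        # Pipeline-parallel forward pass: the input batch is split into micro-batches;
        # seq1 runs on cuda:0 while seq2 runs on cuda:1, and a DNI backward interface
        # (with no synthesizer context) decouples the two stages between micro-batches.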
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        x = self.seq1(s_next).to('cuda:0')
        context = None
        with dni.synthesizer_context(context):
            # print("enter")
            s_prev = self.backward_interface(x.to('cuda:0')).to('cuda:1')

        ret = []

        for s_next in splits:
            # A. s_prev runs on cuda:1
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. s_next runs on cuda:0, which can run concurrently with A
            x = self.seq1(s_next).to('cuda:0')
            context = None
            with dni.synthesizer_context(context):
                # print("enter")
                s_prev = self.backward_interface(x.to('cuda:0')).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)
Example #2
 def forward(self, x, y=None):
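     # VGG-style classifier: during training, a DNI backward interface after block1 is given
     # the one-hot label as synthesizer context; the interfaces after blocks 2-5 are disabled.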
     x = self.conv1(x)  # reshape channels from 1 to 3 for vgg
     x = self.block1(x)
     if self.training:
         context = one_hot(y, 10)
         with dni.synthesizer_context(context):
             x = self.backward_interface_1(x)
     x = self.block2(x)
     # if self.training:
     #     context = one_hot(y, 10)
     #     with dni.synthesizer_context(context):
     #         x = self.backward_interface_2(x)
     x = self.block3(x)
     # if self.training:
     #     context = one_hot(y, 10)
     #     with dni.synthesizer_context(context):
     #         x = self.backward_interface_3(x)
     x = self.block4(x)
     # if self.training:
     #     context = one_hot(y, 10)
     #     with dni.synthesizer_context(context):
     #         x = self.backward_interface_4(x)
     x = self.block5(x)
     # if self.training:
     #     context = one_hot(y, 10)
     #     with dni.synthesizer_context(context):
     #         x = self.backward_interface_5(x)
     x = x.view(x.size(0), -1)
     x = self.classifier(x)
     return F.log_softmax(x, dim=1)
Example #3
    def forward(self, x):
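        # Feature extractor: run seq1, pass the result through a DNI backward interface
        # without a synthesizer context, then flatten.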
        x = self.seq1(x)
        # print(x.shape)
        context = None
        with dni.synthesizer_context(context):
            x = self.backward_interface(x)

        return x.view(x.size(0), -1)
Example #4
 def forward(self, x, y=None, epoch=None, dni_delay=None):
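     # Optionally apply a DNI backward interface during training, but only after `dni_delay`
     # epochs (or immediately if epoch/dni_delay are not given); the synthesizer context is
     # the one-hot label when args.context is set, otherwise None.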
     if args.dni and self.training:
         if (epoch is None or dni_delay is None) or epoch > dni_delay:
             if args.context:
                 context = one_hot(y, 10)
             else:
                 context = None
             with dni.synthesizer_context(context):
                 x = self.backward_interface(x)
     return x
Example #5
    def forward(self, x, y=None):
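        # One block followed by a classifier; during training a DNI backward interface
        # conditioned on the one-hot label is inserted after the block.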
        x = self.block_1(x)
        if self.training:
            context = one_hot(y, 10)
            with dni.synthesizer_context(context):
                x = self.backward_interface_1(x)
        x = x.view(x.size(0), -1)

        x = self.classifier(x)

        return F.log_softmax(x, dim=1)
Example #6
 def forward(self, x, y=None):
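     # MLP with two batch-normalized hidden layers; when DNI is enabled during training,
     # a bidirectional interface (which also receives the flattened input) is inserted
     # after hidden2, with the one-hot label as context if args.context is set.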
     input_flat = x.view(x.size()[0], -1)
     x = self.hidden1_bn(self.hidden1(input_flat))
     x = self.hidden2_bn(self.hidden2(F.relu(x)))
     if args.dni and self.training:
         if args.context:
             context = one_hot(y, 10)
         else:
             context = None
         with dni.synthesizer_context(context):
             x = self.bidirectional_interface(x, input_flat)
     x = self.output_bn(self.output(F.relu(x)))
     return F.log_softmax(x, dim=1)
Example #7
 def forward(self, x, y=None):
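     # MLP with two batch-normalized hidden layers; the DNI backward interface is applied
     # during training when self.use_dni is set, with an optional one-hot label context
     # built on self.device.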
     x = x.view(x.size()[0], -1)
     x = self.hidden1_bn(self.hidden1(x))
     x = self.hidden2_bn(self.hidden2(F.relu(x)))
     if self.use_dni and self.training:
         if self.context:
             context = one_hot(y, 10, self.device)
         else:
             context = None
         with dni.synthesizer_context(context):
             x = self.backward_interface(x)
     x = self.output_bn(self.output(F.relu(x)))
     return F.log_softmax(x, dim=1)
Example #8
    def forward(self, x, y=None):
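        # Three blocks applied in sequence, each followed by its own DNI backward interface
        # during training; the synthesizer context is the one-hot label if args.context is set.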
        for i in range(3):
            x = self.block[i](x)
            if args.dni and self.training:
                if args.context:
                    context = one_hot(y, 10)
                else:
                    context = None
                with dni.synthesizer_context(context):
                    x = self.backward_interfaces[i](x)

        x = x.view(-1, 128)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)
Example #9
    def forward(self, x):
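        # Two-stage model: seq1 feeds a DNI backward interface (no synthesizer context),
        # then seq2 runs on cuda:0 and a fully connected layer produces the output.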
        x = self.seq1(x)
        # print(x.shape)
        context = None
        with dni.synthesizer_context(context):
            x = self.backward_interface(x)

        # x = self.input_trigger(x)
        # x = self.hidden(x)
        # x = self.output(x)

        x = self.seq2(x.to('cuda:0'))

        return self.fc(x.view(x.size(0), -1))
Example #10
 def forward(self, x, y=None):
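     # Small CNN with two conv layers; when DNI is enabled during training, the backward
     # interface sits between the convolutional features and the fully connected head.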
     x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)), 2))
     x = F.max_pool2d(self.conv2_drop(self.conv2_bn(self.conv2(x))), 2)
     if args.dni and self.training:
         if args.context:
             context = one_hot(y, 10)
         else:
             context = None
         with dni.synthesizer_context(context):
             x = self.backward_interface(x)
     x = F.relu(x)
     x = x.view(-1, 320)
     x = F.relu(self.fc1_bn(self.fc1(x)))
     x = F.dropout(x, training=self.training)
     x = self.fc2_bn(self.fc2(x))
     return F.log_softmax(x, dim=1)
Example #11
    def forward(self, x, y=None, epoch=None, dni_delay=None):
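        # Memory-saving variant: with args.save_space only one of seq1/seq2 is resident on
        # the GPU at a time; a DNI backward interface (enabled after dni_delay epochs)
        # decouples the two halves, and peak GPU memory is logged for roughly 1% of calls.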
        verbose = np.random.random() < 0.01

        if args.save_space:
            self.seq2.cpu()
            torch.cuda.empty_cache()

            if verbose:
                torch.cuda.reset_max_memory_allocated()

            self.seq1.cuda()
        
        x = self.seq1(x)

        if args.dni and self.training:
            if (epoch is None or dni_delay is None) or epoch > dni_delay:
                if args.context:
                    context = one_hot(y, 10)
                else:
                    context = None
                with dni.synthesizer_context(context):
                    x = self.backward_interface(x)
        
        if verbose:
            print("with seq 1, GPU mem:", torch.cuda.max_memory_allocated())

        if args.save_space:
            self.seq1.cpu()
            torch.cuda.empty_cache()

            if verbose:
                torch.cuda.reset_max_memory_allocated()

            self.seq2.cuda()
            
        x = self.seq2(x)  # .to('cuda:0')
        
        if verbose:
            print("with seq 2, GPU mem:", torch.cuda.max_memory_allocated())

        x = self.fc(x.view(x.size(0), -1))    
        return F.log_softmax(x, dim=1)