コード例 #1
0
    def __init__(self, output_shape, name='Full Network'):
        """
        Initializes the Full Network.

        Validates the requested frame count, then builds the three sub-networks
        (vgg16, InceptionI3d, Generator) and restores the vgg16 and InceptionI3d
        weights from the module-level weight paths.

        :param output_shape: (5-tuple) The desired output shape for generated videos. Must match video input shape.
                              Legal values: (bsz, 3, 8, 112, 112) and (bsz, 3, 16, 112, 112).
                              Only index 2 (the frame count) is checked here.
        :param name: (str, optional) The name of the network (default 'Full Network').
        Raises:
            ValueError: if 'output_shape' does not contain a legal number of frames.
        """
        frame_count = output_shape[2]
        if frame_count not in self.VALID_FRAME_COUNTS:
            # '%s' rather than '%d': a non-numeric frame count must surface as the
            # documented ValueError, not as a TypeError raised while formatting
            # the message. Integers render identically under both specifiers.
            raise ValueError('Invalid number of frames in desired output: %s' %
                             (frame_count,))

        super(FullNetwork, self).__init__()

        self.net_name = name
        self.output_shape = output_shape

        # map_location='cpu' lets a checkpoint that was saved from a CUDA device
        # be restored on a host without one; load_state_dict then copies the
        # values into the module's own parameters.
        self.vgg = vgg16()
        self.vgg.load_state_dict(
            torch.load(vgg_weights_path, map_location='cpu'))

        self.i3d = InceptionI3d(final_endpoint='Mixed_5c',
                                in_frames=frame_count)
        self.i3d.load_state_dict(
            torch.load(i3d_weights_path, map_location='cpu'))

        # NOTE(review): 1536 presumably is the combined channel count of the
        # vgg and i3d features fed to the generator - confirm against forward().
        self.gen = Generator(in_channels=1536, out_frames=frame_count)
コード例 #2
0
    def __init__(self, vp_value_count, output_shape, name='Full Network'):
        """
        Initializes the Full Network.

        Validates both arguments, then builds the sub-networks: vgg16 and
        InceptionI3d (both constructed with pretrained=True and an explicit
        weights path), followed by Deconv, Expander, Transformer and Generator.

        :param vp_value_count: (int) The number of vp values. Must be one of VALID_VP_VALUE_COUNTS.
        :param output_shape: (5-tuple) The desired output shape for generated videos. Must match video input shape.
                              Legal values: (bsz, 3, 8, 112, 112) and (bsz, 3, 16, 112, 112).
                              Only index 2 (the frame count) is checked here.
        :param name: (str, optional) The name of the network (default 'Full Network').
        Raises:
            ValueError: if 'vp_value_count' is not a legal value count
            ValueError: if 'output_shape' does not contain a legal number of frames.
        """
        # '%s' rather than '%d' in both messages: a non-numeric argument (e.g.
        # None) must surface as the documented ValueError, not as a TypeError
        # raised while formatting the message. Integers render identically.
        if vp_value_count not in self.VALID_VP_VALUE_COUNTS:
            raise ValueError('Invalid number of vp values: %s' % (vp_value_count,))
        frame_count = output_shape[2]
        if frame_count not in self.VALID_FRAME_COUNTS:
            raise ValueError('Invalid number of frames in desired output: %s' % (frame_count,))

        super(FullNetwork, self).__init__()

        self.net_name = name
        self.vp_value_count = vp_value_count
        self.output_shape = output_shape
        self.out_frames = frame_count

        self.vgg = vgg16(pretrained=True, weights_path=vgg_weights_path)
        self.i3d = InceptionI3d(final_endpoint='Mixed_5c_small', in_frames=self.out_frames,
                                pretrained=True, weights_path=i3d_weights_path)

        self.deconv = Deconv(in_channels=256, out_frames=self.out_frames)
        self.exp = Expander(vp_value_count=self.vp_value_count, out_frames=self.out_frames, out_size=28)
        # NOTE(review): the transformer input is presumably 32 feature channels
        # concatenated with the expanded vp values - confirm against forward().
        self.trans = Transformer(in_channels=32 + self.vp_value_count)

        # NOTE(review): 32 + 32 presumably is two 32-channel feature maps
        # concatenated before generation - confirm against forward().
        self.gen = Generator(in_channels=32 + 32, out_frames=self.out_frames)