def __init__(self, output_shape, name='Full Network'):
    """
    Initializes the Full Network.

    Validates the requested frame count, then builds the VGG16 and I3D sub-networks, loads their weights from
    'vgg_weights_path' and 'i3d_weights_path', and creates the generator.

    :param output_shape: (5-tuple) The desired output shape for generated videos. Must match video input shape.
                         Legal values: (bsz, 3, 8, 112, 112) and (bsz, 3, 16, 112, 112)
    :param name: (str, optional) The name of the network (default 'Full Network').
    Raises:
        ValueError: if 'output_shape' does not contain a legal number of frames.
    """
    # Index 2 of the output shape is the frame count; it is validated before any sub-network is built.
    out_frames = output_shape[2]
    if out_frames not in self.VALID_FRAME_COUNTS:
        # %s (not %d) so that a non-integer value is reported through the documented ValueError
        # instead of a TypeError raised by the format operator; integer values format identically.
        raise ValueError('Invalid number of frames in desired output: %s' % (out_frames,))

    super(FullNetwork, self).__init__()

    self.net_name = name
    self.output_shape = output_shape

    # map_location='cpu' lets checkpoints that were saved from a GPU be deserialized on machines
    # without CUDA; load_state_dict then copies the values into the module's own parameters.
    self.vgg = vgg16()
    self.vgg.load_state_dict(torch.load(vgg_weights_path, map_location='cpu'))

    self.i3d = InceptionI3d(final_endpoint='Mixed_5c', in_frames=out_frames)
    self.i3d.load_state_dict(torch.load(i3d_weights_path, map_location='cpu'))

    # NOTE(review): in_channels=1536 is presumably the combined channel count of the features fed to the
    # generator -- confirm against forward().
    self.gen = Generator(in_channels=1536, out_frames=out_frames)
def __init__(self, vp_value_count, output_shape, name='Full Network'):
    """
    Initializes the Full Network.

    Validates both arguments, then builds the VGG16, I3D, deconvolution, expander, transformer and generator
    sub-networks. The VGG16 and I3D sub-networks are constructed with pretrained=True and the weights paths
    'vgg_weights_path' and 'i3d_weights_path'.

    :param vp_value_count: The number of vp values. Must be one of VALID_VP_VALUE_COUNTS.
    :param output_shape: (5-tuple) The desired output shape for generated videos. Must match video input shape.
                         Legal values: (bsz, 3, 8, 112, 112) and (bsz, 3, 16, 112, 112)
    :param name: (str, optional) The name of the network (default 'Full Network').
    Raises:
        ValueError: if 'vp_value_count' is not a legal value count
        ValueError: if 'output_shape' does not contain a legal number of frames.
    """
    # %s (not %d) in both messages so that a non-integer argument is reported through the documented
    # ValueError instead of a TypeError raised by the format operator; integer values format identically.
    if vp_value_count not in self.VALID_VP_VALUE_COUNTS:
        raise ValueError('Invalid number of vp values: %s' % (vp_value_count,))

    # Index 2 of the output shape is the frame count.
    out_frames = output_shape[2]
    if out_frames not in self.VALID_FRAME_COUNTS:
        raise ValueError('Invalid number of frames in desired output: %s' % (out_frames,))

    super(FullNetwork, self).__init__()

    self.net_name = name
    self.vp_value_count = vp_value_count
    self.output_shape = output_shape
    self.out_frames = out_frames

    self.vgg = vgg16(pretrained=True, weights_path=vgg_weights_path)
    self.i3d = InceptionI3d(final_endpoint='Mixed_5c_small', in_frames=self.out_frames,
                            pretrained=True, weights_path=i3d_weights_path)

    # NOTE(review): the channel counts (256, 32 + vp_value_count, 32 + 32) and out_size=28 presumably mirror
    # the feature-map shapes combined in forward() -- confirm there before changing any of them.
    self.deconv = Deconv(in_channels=256, out_frames=self.out_frames)
    self.exp = Expander(vp_value_count=self.vp_value_count, out_frames=self.out_frames, out_size=28)
    self.trans = Transformer(in_channels=32 + self.vp_value_count)
    self.gen = Generator(in_channels=32 + 32, out_frames=self.out_frames)