def __init__(self, args):
        """
        Build the feature extractor, the generator and the training utilities.

        :param args: parsed run configuration. This constructor reads
            ``train_frames``, ``num_of_vid``, ``crop_size``, ``lr``,
            ``epochs``, ``checkpoint`` and ``lr_decay`` from it.

        Attributes set here:

        :ivar train_frames: number of frames per clip used for training
        :ivar num_of_vid: number of videos (forwarded to the generator)
        :ivar extractor: pretrained 3D ResNet-101 used for feature extraction,
            put in eval mode
        :ivar G: the generator being trained, put in train mode
        :ivar optimizerG: Adam optimizer over the generator's parameters only
        :ivar train_data: training data loader
        :ivar device0: device that holds the extractor
        :ivar device1: device that holds the generator
        """
        self.train_frames = args.train_frames
        self.num_of_vid = args.num_of_vid
        self.crop_size = args.crop_size
        # NOTE(review): `pretrain_path` is not read from `args`. It is presumably
        # a module-level name defined elsewhere in this file. If it is not, this
        # line raises NameError. The test-time constructor reads
        # `args.pretrained_model_path` instead; consider unifying the two.
        self.pretrain_path = pretrain_path
        print("===> Building model")
        # Model setting
        # Pretrained model setting: the temporal depth of the extractor matches
        # the number of training frames.
        self.extractor = resnet101(num_classes=400,
                                   shortcut_type='B',
                                   cardinality=32,
                                   sample_size=self.crop_size,
                                   sample_duration=self.train_frames)

        # Load the pretrained weights and keep the extractor in eval mode.
        # Its parameters are not given to the optimizer below, so they are
        # never updated.
        weight = get_pretrain_weight(self.pretrain_path, self.extractor)
        self.extractor.load_state_dict(weight)
        self.extractor.eval()

        self.G = Generator(shortcut_type='B',
                           cardinality=32,
                           sample_size=self.crop_size,
                           frame_num=self.train_frames,
                           num_of_vid=self.num_of_vid)
        self.G.train()

        # Optimizer: only the generator is trained.
        self.optimizerG = optim.Adam(self.G.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
        # Losses
        self.CE_loss = nn.CrossEntropyLoss()
        self.L1_loss = nn.L1Loss()
        self.L2_loss = nn.MSELoss()
        # NOTE(review): nn.Softmax() without `dim` lets PyTorch infer the
        # dimension (deprecated, emits a warning). Kept as-is because the call
        # sites are not visible here. TODO: pass an explicit dim.
        self.Softmax = nn.Softmax()
        # Data
        self.train_data = train_data_loader
        # GPU devices
        self.device0 = device0
        self.device1 = device1

        self.epochs = args.epochs
        self.avg_G_loss_arr = []
        self.checkpoint = args.checkpoint
        # NOTE(review): checkpoint_set() runs before self.lr_decay is assigned.
        # Verify that it does not read self.lr_decay.
        self.checkpoint_set()
        self.lr_decay = args.lr_decay

        # CUDA: the extractor and the generator are placed on separate devices
        # (presumably to split memory across two GPUs; confirm).
        if torch.cuda.is_available():
            self.extractor.to(self.device0)
            self.G.cuda(self.device1)
Ejemplo n.º 2
0
    def __init__(self, args):
        """
        Build the feature extractor and load the trained generator for testing.

        :param args: parsed run configuration. This constructor reads
            ``test_frame``, ``num_of_vid``, ``crop_size``,
            ``pretrained_model_path``, ``train_model``, ``log_path``,
            ``original_vid_dir``, ``demo_dir``, ``out_fps``, ``height``,
            ``width`` and ``fps`` from it.
        """
        self.test_frame = args.test_frame
        self.num_of_vid = args.num_of_vid
        self.crop_size = args.crop_size
        print("===> Building model")
        # Pretrained model setting: the temporal depth of the extractor matches
        # the number of test frames.
        self.extractor = resnet101(num_classes=400,
                                   shortcut_type='B',
                                   cardinality=32,
                                   sample_size=self.crop_size,
                                   sample_duration=self.test_frame)

        # Load the pretrained model in eval mode.
        weight = get_pretrain_weight(args.pretrained_model_path,
                                     self.extractor)
        self.extractor.load_state_dict(weight)
        self.extractor.eval()

        # The trained generator is loaded as a whole pickled module.
        # torch.load unpickles arbitrary objects, so only load trusted files.
        # NOTE(review): on PyTorch >= 2.6, torch.load defaults to
        # weights_only=True, which rejects full-module pickles. Loading may
        # then need weights_only=False.
        # Without CUDA, remap GPU-saved tensors to the CPU so loading does not
        # fail. With CUDA, keep the default behavior (map_location=None).
        map_location = None if torch.cuda.is_available() else "cpu"
        self.G = torch.load(args.train_model, map_location=map_location)
        self.G.eval()

        # Data loader
        self.test_data = test_data_loader

        # CUDA GPU device setting: both networks go to device1 here.
        # device0 is stored but not used by this constructor.
        self.device0 = device0
        self.device1 = device1
        if torch.cuda.is_available():
            self.extractor.cuda(self.device1)
            self.G.cuda(self.device1)

        # Output / video settings
        self.log_dir = args.log_path
        self.original_vid_dir = args.original_vid_dir
        self.demo_dir = args.demo_dir
        self.out_fps = args.out_fps
        self.height = args.height
        self.width = args.width
        self.in_fps = args.fps