Exemplo n.º 1
0
def load_checkpoints(config, checkpoint, blend_scale=0.125, first_order_motion_model=False, cpu=False):
    """Build the part-swap reconstruction and segmentation modules from a config/checkpoint.

    Parameters:
        config (str)                    -- path to the YAML config file
        checkpoint (str)                -- path to the .pth checkpoint file
        blend_scale (float)             -- blending scale passed to PartSwapGenerator
        first_order_motion_model (bool) -- build the FOMM-style generator variant
        cpu (bool)                      -- keep everything on CPU instead of CUDA

    Returns:
        (reconstruction_module, segmentation_module), both in eval mode and
        wrapped in DataParallelWithCallback when running on GPU.
    """
    with open(config) as f:
        # Bare yaml.load() without a Loader is deprecated (and a TypeError on
        # PyYAML >= 6); FullLoader preserves the previous default behavior.
        config = yaml.load(f, Loader=yaml.FullLoader)

    reconstruction_module = PartSwapGenerator(blend_scale=blend_scale,
                                              first_order_motion_model=first_order_motion_model,
                                              **config['model_params']['reconstruction_module_params'],
                                              **config['model_params']['common_params'])

    if not cpu:
        reconstruction_module.cuda()

    segmentation_module = SegmentationModule(**config['model_params']['segmentation_module_params'],
                                             **config['model_params']['common_params'])
    if not cpu:
        segmentation_module.cuda()

    # Remap saved tensors onto CPU when CUDA is not being used.
    if cpu:
        checkpoint = torch.load(checkpoint, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint)

    load_reconstruction_module(reconstruction_module, checkpoint)
    load_segmentation_module(segmentation_module, checkpoint)

    if not cpu:
        reconstruction_module = DataParallelWithCallback(reconstruction_module)
        segmentation_module = DataParallelWithCallback(segmentation_module)

    reconstruction_module.eval()
    segmentation_module.eval()

    return reconstruction_module, segmentation_module
Exemplo n.º 2
0
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build generator and keypoint detector from a YAML config and a checkpoint.

    Parameters:
        config_path (str)     -- path to the YAML config file
        checkpoint_path (str) -- path to the .pth checkpoint file
        cpu (bool)            -- run on CPU instead of CUDA

    Returns:
        (generator, kp_detector), both in eval mode and wrapped in
        DataParallelWithCallback when running on GPU.
    """
    with open(config_path) as f:
        # Bare yaml.load() without a Loader is deprecated (and a TypeError on
        # PyYAML >= 6); FullLoader preserves the previous default behavior.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()

    # Remap saved tensors onto CPU when CUDA is not being used.
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build generator and keypoint detector from a YAML config and a checkpoint.

    Parameters:
        config_path (str)     -- path to the YAML config file
        checkpoint_path (str) -- path to the .pth checkpoint file
        cpu (bool)            -- run on CPU instead of CUDA

    Returns:
        (generator, kp_detector), both in eval mode and wrapped in
        DataParallelWithCallback (even on CPU, matching the original behavior).
    """
    with open(config_path) as f:
        # Bare yaml.load() without a Loader is deprecated (and a TypeError on
        # PyYAML >= 6); FullLoader preserves the previous default behavior.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config["model_params"]["generator_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        generator.cpu()
    else:
        generator.cuda()

    kp_detector = KPDetector(
        **config["model_params"]["kp_detector_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        kp_detector.cpu()
    else:
        kp_detector.cuda()

    # map_location=None keeps tensors on their original (GPU) device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu" if cpu else None)
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Exemplo n.º 4
0
 def cuda(self):
     """Move the wrapped model to GPU and wrap it for multi-GPU data parallelism."""
     self.model.cuda()
     # With sync batch-norm enabled, DataParallelWithCallback is required so
     # the synchronized BN layers can exchange statistics across replicas;
     # otherwise a plain nn.DataParallel wrapper is sufficient.
     if self.get_param('network.sync_bn', False):
         self.model = DataParallelWithCallback(self.model,
                                               dim=self.batch_axis)
     else:
         self.model = nn.DataParallel(self.model, dim=self.batch_axis)
Exemplo n.º 5
0
def load_checkpoints(config_path, checkpoint_path, device='cuda'):
    """Build generator and keypoint detector from a YAML config and a checkpoint.

    Parameters:
        config_path (str)     -- path to the YAML config file
        checkpoint_path (str) -- path to the .pth checkpoint file
        device (str)          -- torch device spec, e.g. 'cuda' or 'cpu'

    Returns:
        (generator, kp_detector), both in eval mode and wrapped in
        DataParallelWithCallback.
    """
    with open(config_path) as f:
        # Bare yaml.load() without a Loader is deprecated (and a TypeError on
        # PyYAML >= 6); FullLoader preserves the previous default behavior.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    generator.to(device)

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    kp_detector.to(device)

    # map_location follows the requested device so CPU loads also work.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Exemplo n.º 6
0
def load_checkpoints(config_path, checkpoint_path, device="cuda"):
    """Instantiate generator and keypoint detector, restore their weights, and
    return both wrapped in DataParallelWithCallback in eval mode."""
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_params = config["model_params"]
    common_params = model_params["common_params"]

    generator = OcclusionAwareGenerator(
        **model_params["generator_params"],
        **common_params,
    )
    generator.to(device)

    kp_detector = KPDetector(
        **model_params["kp_detector_params"],
        **common_params,
    )
    kp_detector.to(device)

    # Restore weights on the same device the modules were moved to.
    state = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state["generator"])
    kp_detector.load_state_dict(state["kp_detector"])

    generator = DataParallelWithCallback(generator)
    generator.eval()
    kp_detector = DataParallelWithCallback(kp_detector)
    kp_detector.eval()

    return generator, kp_detector
Exemplo n.º 7
0
    def testSyncBatchNormSyncEval(self):
        """Check eval-mode parity of SynchronizedBatchNorm1d against nn.BatchNorm1d."""
        eps = 1e-5
        reference_bn = nn.BatchNorm1d(10, eps=eps, affine=False)
        sync_bn = DataParallelWithCallback(
            SynchronizedBatchNorm1d(10, eps=eps, affine=False),
            device_ids=[0, 1])

        reference_bn.cuda()
        sync_bn.cuda()

        inputs = torch.rand(16, 10)
        self._checkBatchNormResult(reference_bn, sync_bn, inputs, False, cuda=True)
Exemplo n.º 8
0
    def testSyncBatchNorm2DSyncTrain(self):
        """Check train-mode parity of SynchronizedBatchNorm2d against nn.BatchNorm2d."""
        reference_bn = nn.BatchNorm2d(10)
        sync_bn = DataParallelWithCallback(
            SynchronizedBatchNorm2d(10), device_ids=[0, 1])

        reference_bn.cuda()
        sync_bn.cuda()

        inputs = torch.rand(16, 10, 16, 16)
        self._checkBatchNormResult(reference_bn, sync_bn, inputs, True, cuda=True)
    def testSyncBatchNorm2DSyncTrain(self):
        """Train-mode comparison: SynchronizedBatchNorm2d vs. nn.BatchNorm2d on two GPUs."""
        num_features = 10
        bn = nn.BatchNorm2d(num_features)
        sync_bn = SynchronizedBatchNorm2d(num_features)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        for module in (bn, sync_bn):
            module.cuda()

        self._checkBatchNormResult(
            bn, sync_bn, torch.rand(16, num_features, 16, 16), True, cuda=True)
    def testSyncBatchNormSyncEval(self):
        """Eval-mode comparison: SynchronizedBatchNorm1d vs. nn.BatchNorm1d on two GPUs."""
        num_features = 10
        bn = nn.BatchNorm1d(num_features, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(num_features, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        for module in (bn, sync_bn):
            module.cuda()

        self._checkBatchNormResult(
            bn, sync_bn, torch.rand(16, num_features), False, cuda=True)
Exemplo n.º 11
0
def transfer(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Run motion transfer over (driving, source) video pairs and save outputs.

    For each pair produced by PairedDataset, predicts a video with
    `transfer_one`, writes the frames as one horizontally concatenated .png
    (for evaluation) into `<log_dir>/transfer/png`, and a visualization video
    into `<log_dir>/transfer`.

    Raises:
        AttributeError: if `checkpoint` is None (a checkpoint is mandatory).
    """
    log_dir = os.path.join(log_dir, 'transfer')
    png_dir = os.path.join(log_dir, 'png')
    transfer_params = config['transfer_params']

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=transfer_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='transfer'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            # Move every tensor in the batch to GPU; non-tensor entries
            # (e.g. name strings) have no .cuda() and are passed through.
            x = {
                key: value if not hasattr(value, 'cuda') else value.cuda()
                for key, value in x.items()
            }
            driving_video = x['driving_video']
            # Only the first frame of the source video is used as the source image.
            source_image = x['source_video'][:, :, :1, :, :]
            out = transfer_one(generator, kp_detector, source_image,
                               driving_video, transfer_params)
            img_name = "-".join([x['driving_name'][0], x['source_name'][0]])

            # Store to .png for evaluation
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            # (B, C, T, H, W) -> (T, H, W, C), then concatenate frames horizontally.
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0],
                                             axis=1)
            imageio.imsave(os.path.join(png_dir, img_name + '.png'),
                           (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(
                **config['visualizer_params']).visualize_transfer(
                    driving_video=driving_video,
                    source_image=source_image,
                    out=out)
            imageio.mimsave(
                os.path.join(log_dir, img_name + transfer_params['format']),
                image)
Exemplo n.º 12
0
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Reconstruct each video from its own first frame and report the L1 loss.

    For every video in `dataset`, uses frame 0 as the source and every frame as
    driving; saves the concatenated predictions as .png, a visualization video,
    and prints the mean absolute reconstruction error at the end.

    Raises:
        AttributeError: if `checkpoint` is None (a checkpoint is mandatory).
    """
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        # NOTE(review): '>' stops after num_videos + 1 videos; '>=' would stop
        # at exactly num_videos -- confirm intent before changing.
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            # Keypoints of the source (first) frame are computed once per video.
            kp_source = kp_detector(x['video'][:, :, 0])
            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                # Drop intermediate deformations to keep memory/visualization small.
                del out['sparse_deformed']
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
                                                                                    driving=driving, out=out)
                visualizations.append(visualization)

                # Per-frame mean absolute error against the driving frame.
                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())

            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))

            image_name = x['name'][0] + config['reconstruction_params']['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))
Exemplo n.º 13
0
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build generator, region predictor, and AVD network from config + checkpoint.

    Parameters:
        config_path (str)     -- path to the YAML config file
        checkpoint_path (str) -- path to the .pth checkpoint file
        cpu (bool)            -- run on CPU instead of CUDA

    Returns:
        (generator, region_predictor, avd_network), all in eval mode and
        wrapped in DataParallelWithCallback when running on GPU.
    """
    with open(config_path) as f:
        # Bare yaml.load() without a Loader is deprecated (and a TypeError on
        # PyYAML >= 6); FullLoader preserves the previous default behavior.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = Generator(num_regions=config['model_params']['num_regions'],
                          num_channels=config['model_params']['num_channels'],
                          **config['model_params']['generator_params'])
    if not cpu:
        generator.cuda()

    region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],
                                       num_channels=config['model_params']['num_channels'],
                                       estimate_affine=config['model_params']['estimate_affine'],
                                       **config['model_params']['region_predictor_params'])
    if not cpu:
        region_predictor.cuda()

    avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'],
                             **config['model_params']['avd_network_params'])
    if not cpu:
        avd_network.cuda()

    # Remap saved tensors onto CPU when CUDA is not being used.
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    region_predictor.load_state_dict(checkpoint['region_predictor'])
    # Older checkpoints may predate the AVD network; load it only if present.
    if 'avd_network' in checkpoint:
        avd_network.load_state_dict(checkpoint['avd_network'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        region_predictor = DataParallelWithCallback(region_predictor)
        avd_network = DataParallelWithCallback(avd_network)

    generator.eval()
    region_predictor.eval()
    avd_network.eval()

    return generator, region_predictor, avd_network
Exemplo n.º 14
0
    def __init__(self, opt):
        """Create the Pix2Pix model (data-parallel when GPUs are given) and,
        in training mode, its generator/discriminator optimizers."""
        self.opt = opt
        model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            # Wrap for multi-GPU; keep a handle on the underlying module so
            # optimizer creation and attribute access bypass the wrapper.
            model = DataParallelWithCallback(model, device_ids=opt.gpu_ids)
            self.pix2pix_model = model
            self.pix2pix_model_on_one_gpu = model.module
        else:
            self.pix2pix_model = model
            self.pix2pix_model_on_one_gpu = model

        self.generated = None
        if opt.isTrain:
            optimizers = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.optimizer_G, self.optimizer_D = optimizers
            self.old_lr = opt.lr
Exemplo n.º 15
0
    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):
        """Set up the gait-recognition model: SetNet encoder, triplet loss,
        Adam optimizer, and bookkeeping for training metrics.

        Note: `batch_size` must be a 2-element pair -- it is unpacked into
        (P, M) below (presumably persons-per-batch and samples-per-person;
        TODO confirm against the caller).
        """
        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        # Unpack the batch structure; total batch size is P * M.
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        # Encoder and loss are both wrapped for (sync-BN aware) data
        # parallelism before being moved to GPU.
        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        # Only the encoder's parameters are optimized; the triplet loss has
        # no trainable parameters registered here.
        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        # Running training metrics, reset/consumed by the training loop.
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        self.sample_type = 'all'
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 (None/empty means CPU)

    Return an initialized network.
    """
    # Use None instead of a mutable default list ([]), which would be shared
    # across calls.
    if gpu_ids is None:
        gpu_ids = []
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])  # move to the first GPU before wrapping
        net = DataParallelWithCallback(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
Exemplo n.º 17
0
def load_checkpoints(kp_detector_path, generator_path, cpu=False):
    """Load TorchScript generator and keypoint detector modules.

    Unless `cpu` is set, both modules are moved to CUDA and wrapped in
    DataParallelWithCallback. Both are returned in eval mode.
    """
    generator = torch.jit.load(generator_path)
    kp_detector = torch.jit.load(kp_detector_path)

    # All GPU-specific steps are gated on the same condition.
    if not cpu:
        generator.cuda()
        kp_detector.cuda()
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
Exemplo n.º 18
0
def main():
    """Build the CULane baseline model, its SGD optimizer (with doubled LR for
    biases), optionally resume from a checkpoint, and start training.

    Relies on module-level globals: args, culane, ckpt_path, exp_name,
    log_path, check_mkdir, train.
    """
    net = Baseline(num_classes=culane.num_classes,
                   deep_base=args['deep_base']).cuda().train()
    net = DataParallelWithCallback(net)

    # Classic Caffe-style trick: bias parameters get 2x the base learning rate.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['base_lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['base_lr']
    }],
                          momentum=args['momentum'])

    if len(args['checkpoint']) > 0:
        print('training resumes from \'%s\'' % args['checkpoint'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint_optim.pth')))
        # Restore the per-group learning rates after loading optimizer state.
        optimizer.param_groups[0]['lr'] = 2 * args['base_lr']
        optimizer.param_groups[1]['lr'] = args['base_lr']

    check_mkdir(os.path.join(ckpt_path, exp_name))
    # 'w' truncates any previous log; the full arg dict is recorded up front.
    open(log_path, 'w').write(str(args) + '\n\n')

    train(net, optimizer)
Exemplo n.º 19
0
    def restore(self, checkpoint):
        """(Re)build generator/discriminator and their optimizers, optionally
        restoring all state from a checkpoint file, then attach LR schedulers.

        Parameters:
            checkpoint (str or None) -- path to a torch checkpoint whose dict
                maps attribute names (and 'epoch') to saved state.
        """
        self.epoch = 0

        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator,
                                                  device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0,
            eps=1e-8)

        self.discriminator = DCDiscriminator(
            **self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(
            self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'],
                   self.config['b2_discriminator']),
            weight_decay=0,
            eps=1e-8)

        if checkpoint is not None:
            data = torch.load(checkpoint)
            # BUG FIX: iterating a dict directly yields only keys, so the
            # original 'for key, value in data:' raised at unpack time.
            # .items() yields the (key, value) pairs this loop expects.
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    # Keys match attribute names, e.g. 'generator',
                    # 'optimizer_generator'.
                    self.__dict__[key].load_state_dict(value)

        # Linear decay to zero over the configured number of epochs.
        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)
Exemplo n.º 20
0
# Command-line flags for the overlap-coherence training scheme.
parser.add_argument('--train', action='store_true', help='train the model')
parser.add_argument('--test', action='store_true', help='test the model')
parser.add_argument('--overlap_rate', type=float, default=0.25, help='the overlap rate of the overlap coherence training scheme')
parser.add_argument('--lambdaa', type=float, default=0.0, help='weight of the overlap coherence loss')

opt = parser.parse_args()
print(opt)

# All CUDA allocations default to the first requested device.
torch.cuda.set_device(opt.device_id[0])

# ######################## Module #################################
print('Building model')
model = actionModel(opt.class_num, batch_norm=True, dropout=opt.dropout, TD_rate=opt.TD_rate, image_size=opt.img_size, syn_bn=opt.syn_bn, test_scheme=3)
print(model)
# Sync batch-norm needs DataParallelWithCallback; otherwise the plain
# torch.nn.DataParallel wrapper is used.
if opt.syn_bn:
    model = DataParallelWithCallback(model, device_ids=opt.device_id).cuda()
else:
    model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda()
print("Channels: " + str(model.module.channels))

# ########################Optimizer#########################
# RNN and shortcut branches use LR[0]; the classifier uses LR[1].
optimizer = torch.optim.SGD([{'params': model.module.RNN.parameters(), 'lr': opt.LR[0]},
                             {'params': model.module.ShortCut.parameters(), 'lr': opt.LR[0]},
                             {'params': model.module.classifier.parameters(), 'lr': opt.LR[1]}
                             ], lr=opt.LR[1], weight_decay=opt.weight_decay, momentum=0.9)

# ###################### Loss Function ####################################
# NOTE(review): 'reduce' is deprecated in modern PyTorch in favor of
# 'reduction' -- confirm against the pinned torch version.
loss_classification_func = nn.NLLLoss(reduce=True)

def loss_overlap_coherence_func(pre, cur):
Exemplo n.º 21
0
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Animate source images with driving videos and save results to disk.

    For each (driving, source) pair from PairedDataset, normalizes driving
    keypoints relative to the initial frame, generates the animation, and
    writes a concatenated .png plus a visualization video under
    `<log_dir>/animation`.

    Raises:
        AttributeError: if `checkpoint` is None (a checkpoint is mandatory).
    """
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            # Keypoints of the first driving frame anchor the normalization.
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm

                # Drop intermediate deformations to keep memory/visualization small.
                del out['sparse_deformed']

                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
                                                                                    driving=driving_frame, out=out)
                # (Removed a redundant self-assignment 'visualization = visualization'.)
                visualizations.append(visualization)

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
Exemplo n.º 22
0
def train(config,
          generator,
          discriminator,
          kp_detector,
          checkpoint,
          log_dir,
          dataset,
          device_ids,
          load_weights_only=False,
          use_both=False,
          update_kp=True):
    """Cycle-style adversarial training loop.

    Each iteration runs two generation passes: driving_A + src_B -> B', then
    the generated B' + src_A -> A'. Generator and discriminator losses for
    both passes are logged to TensorBoard and to the project Logger; epoch
    checkpoints are saved via logger.log_epoch.

    Parameters:
        load_weights_only (bool) -- restore only module weights from the
            checkpoint, restarting epoch/iteration counters at 0.
        use_both (bool) -- use both ground-truth and approximated keypoints
            (second model variant) instead of ground truth only.
        update_kp (bool) -- whether the keypoint detector optimizer steps.
    """
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:

        if load_weights_only:
            Logger.load_cpk(checkpoint, generator, discriminator, kp_detector)
            start_epoch = 0
            it = 0
        else:
            # Full resume: also restores optimizer state and counters.
            saved_start_epoch, saved_it = Logger.load_cpk(
                checkpoint, generator, discriminator, kp_detector,
                optimizer_generator, optimizer_discriminator,
                optimizer_kp_detector)
            start_epoch = saved_start_epoch
            it = saved_it
    else:
        start_epoch = 0
        it = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=start_epoch - 1)

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    generator_full_par = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    discriminator_full_par = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    if not os.path.isdir(log_dir + 'tb_log/'):
        os.mkdir(log_dir + 'tb_log/')

    writer = SummaryWriter(log_dir=log_dir + 'tb_log/')

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):

            total_discriminator_loss = 0.
            total_generator_loss = 0.

            for i, x in enumerate(dataloader):

                # driving = driving_A, source = src_B --> driving_B'
                mn1_dict = {}
                mn1_dict['source'] = x['src_B']
                mn1_dict['video'] = x['driving_A']
                mn1_dict['gt_video'] = x['driving_B']
                '''This code is for the first model where we use ground truth key points for both'''
                if not use_both:
                    loss_values_B, loss_B, generated_B, gt_kp_joined_B = compute_loss(
                        generator_full_par, mn1_dict)
                else:
                    '''This code is for the second model where we give GT and approx kp respectively '''
                    out_B = compute_loss(generator_full_par,
                                         mn1_dict,
                                         use_both=True)
                    loss_values_B, loss_B, generated_B, approx_kp_joined_B, gt_kp_joined_B = out_B

                # driving = generated_B (driving_B'), source = src_A --> driving_A'
                mn2_dict = {}
                mn2_dict['source'] = x['src_A']
                mn2_dict['video'] = generated_B['video_prediction']
                mn2_dict['gt_video'] = x['driving_A']

                # First model - see above
                if not use_both:
                    loss_values_A, loss_A, generated_A, gt_kp_joined_A = compute_loss(
                        generator_full_par, mn2_dict)
                else:
                    # Second model - see above
                    out_A = compute_loss(generator_full_par,
                                         mn2_dict,
                                         use_both=True)
                    loss_values_A, loss_A, generated_A, approx_kp_joined_A, gt_kp_joined_A = out_A

                loss = loss_B + loss_A
                total_generator_loss += loss

                # NOTE(review): the next two lines immediately overwrite both
                # 'loss' and the accumulated total with the B-pass loss only,
                # so the A-pass never contributes to the backward pass or the
                # epoch total. Looks like leftover experimentation -- confirm.
                loss = loss_B
                total_generator_loss = loss

                writer.add_scalar('generator loss B', loss_B.item(), it)
                writer.add_scalar('generator loss A', loss_A.item(), it)
                writer.add_scalar('generator loss', loss.item(), it)

                # Keep the graph alive when the kp detector is updated through
                # the discriminator pass later in this iteration.
                loss.backward(
                    retain_graph=not train_params['detach_kp_discriminator'])
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_discriminator.zero_grad()

                if train_params['detach_kp_discriminator'] and update_kp:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                generator_loss_values = {}
                generator_loss_values['A'] = [
                    val.detach().cpu().numpy() for val in loss_values_A
                ]
                generator_loss_values['B'] = [
                    val.detach().cpu().numpy() for val in loss_values_B
                ]

                # Discriminator passes reuse the generated outputs from above.
                if not use_both:
                    loss_values_B = discriminator_full_par(
                        mn1_dict, gt_kp_joined_B, generated_B)
                    loss_values_A = discriminator_full_par(
                        mn2_dict, gt_kp_joined_A, generated_A)
                else:
                    loss_values_B = discriminator_full_par(
                        mn1_dict, approx_kp_joined_B, generated_B,
                        gt_kp_joined_B)
                    loss_values_A = discriminator_full_par(
                        mn2_dict, approx_kp_joined_A, generated_A,
                        gt_kp_joined_A)

                loss_values_B = [val.mean() for val in loss_values_B]
                loss_values_A = [val.mean() for val in loss_values_A]

                loss_B = sum(loss_values_B)
                loss_A = sum(loss_values_A)

                loss = loss_A + loss_B
                total_discriminator_loss += loss

                # NOTE(review): same clobbering pattern as the generator side
                # -- the A-pass discriminator loss is discarded. Confirm.
                loss = loss_B
                total_discriminator_loss = loss

                writer.add_scalar('disc loss B', loss_B.item(), it)
                writer.add_scalar('disc loss A', loss_A.item(), it)
                writer.add_scalar('disc loss', loss.item(), it)

                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.zero_grad()
                if not train_params['detach_kp_discriminator'] and update_kp:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                discriminator_loss_values = {}
                discriminator_loss_values['A'] = [
                    val.detach().cpu().numpy() for val in loss_values_A
                ]
                discriminator_loss_values['B'] = [
                    val.detach().cpu().numpy() for val in loss_values_B
                ]

                values = {
                    'A':
                    generator_loss_values['A'] +
                    discriminator_loss_values['A'],
                    'B':
                    generator_loss_values['B'] + discriminator_loss_values['B']
                }

                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=values['B'],
                    inp=mn1_dict,
                    out=generated_B,
                    name='src_B_driving_A')
                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=values['A'],
                    inp=mn2_dict,
                    out=generated_A,
                    name='src_A_driving_B')
                it += 1

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            writer.add_scalar('generator loss / train',
                              total_generator_loss / (i + 1), epoch)
            writer.add_scalar('discriminator loss / train',
                              total_discriminator_loss / (i + 1), epoch)

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
Exemplo n.º 23
0
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    """Adversarial training loop for generator, discriminator and kp detector.

    Args:
        config: parsed experiment config; reads ``train_params`` and
            ``visualizer_params``.
        generator, discriminator, kp_detector: the three networks to train.
        checkpoint: optional checkpoint path to resume from (restores models,
            optimizers, start epoch and iteration counter).
        log_dir: directory the ``Logger`` writes logs/checkpoints into.
        dataset: training dataset, wrapped in a ``DataLoader`` below.
        device_ids: GPU ids for ``DataParallelWithCallback``.
    """
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch, it = Logger.load_cpk(checkpoint, generator, discriminator,
                                          kp_detector, optimizer_generator,
                                          optimizer_discriminator,
                                          optimizer_kp_detector)
    else:
        start_epoch = 0
        it = 0

    # BUG FIX: the discriminator and kp-detector schedulers previously wrapped
    # optimizer_generator (copy-paste error), so those two optimizers never
    # had their learning rate decayed. Each scheduler now owns its optimizer.
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=start_epoch - 1)

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    generator_full_par = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    discriminator_full_par = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # ---- generator / kp-detector step ----
                out = generator_full_par(x)
                loss_values = out[:-2]
                generated = out[-2]
                kp_joined = out[-1]
                loss_values = [val.mean() for val in loss_values]
                loss = sum(loss_values)

                # Keep the graph alive when the kp detector also receives
                # gradients from the discriminator pass below.
                loss.backward(
                    retain_graph=not train_params['detach_kp_discriminator'])
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_discriminator.zero_grad()
                if train_params['detach_kp_discriminator']:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                generator_loss_values = [
                    val.detach().cpu().numpy() for val in loss_values
                ]

                # ---- discriminator step ----
                loss_values = discriminator_full_par(x, kp_joined, generated)
                loss_values = [val.mean() for val in loss_values]
                loss = sum(loss_values)

                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.zero_grad()
                if not train_params['detach_kp_discriminator']:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                discriminator_loss_values = [
                    val.detach().cpu().numpy() for val in loss_values
                ]

                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=generator_loss_values + discriminator_loss_values,
                    inp=x,
                    out=generated)
                it += 1

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
Exemplo n.º 24
0
def debug_generator(generator, kp_to_skl_gt, loader, train_params,
                    logger, device_ids, tgt_batch=None):
    """Debug/overfit loop that trains only the conditional generator on the
    perceptual loss and periodically logs reconstruction images.

    Args:
        generator: generator network whose parameters are optimized.
        kp_to_skl_gt: callable mapping keypoints to a skeleton image.
        loader: training data loader; each element is passed to ``split_data``.
        train_params: dict with 'lr', 'betas', 'epoch_milestones',
            'num_epochs', 'detach_kp_discriminator' and 'log_params'.
        logger: TensorBoard-style logger with add_scalar/add_image.
        device_ids: GPU ids for ``DataParallelWithCallback``.
        tgt_batch: optional fixed batch used for eval-view visualizations.
    """
    log_params = train_params['log_params']
    genModel = ConditionalGenerator2D(generator, train_params)
    genModel = DataParallelWithCallback(genModel, device_ids=device_ids)

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=train_params['betas'])
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1, last_epoch=-1)

    k = 0  # index of the sample visualized inside a batch; cycles 0..3
    train_views = [0, 1, 3]
    eval_views = [2]
    if tgt_batch is not None:
        tgt_batch_samples = split_data(tgt_batch,
                                       train_views=train_views,
                                       eval_views=eval_views)
        with torch.no_grad():
            tgt_batch_samples['gt_skl'] = kp_to_skl_gt(
                tgt_batch_samples['kps'].to('cuda')).unsqueeze(1)
            tgt_batch_samples['gt_skl_eval'] = kp_to_skl_gt(
                tgt_batch_samples['kps_eval'].to('cuda')).unsqueeze(1)

    for epoch in range(train_params['num_epochs']):
        for i, batch in enumerate(tqdm(loader)):
            # BUG FIX: split_data was previously called on the undefined names
            # (img, annots, ref_img), raising NameError on the first
            # iteration. The loop variable `batch` is what must be split.
            batch_samples = split_data(batch,
                                       train_views=train_views,
                                       eval_views=eval_views)
            with torch.no_grad():
                batch_samples['gt_skl'] = kp_to_skl_gt(
                    batch_samples['kps'].to('cuda')).unsqueeze(1)

            generator_out = genModel(batch_samples['imgs'],
                                     batch_samples['ref_imgs'],
                                     batch_samples['gt_skl'])

            # Generator update driven by the perceptual loss only.
            loss_generator = generator_out['perceptual_loss']
            loss_generator = [x.mean() for x in loss_generator]
            loss_gen = sum(loss_generator)
            loss_gen.backward(
                retain_graph=not train_params['detach_kp_discriminator'])
            optimizer_generator.step()
            optimizer_generator.zero_grad()

            # ----- logging -----
            logger.add_scalar("Generator Loss",
                              loss_gen.item(),
                              epoch * len(loader) + i + 1)
            if i in log_params['log_imgs']:
                if tgt_batch is not None:
                    with torch.no_grad():
                        genModel.eval()
                        generator_out_eval = genModel(
                            tgt_batch_samples['imgs_eval'],
                            tgt_batch_samples['ref_imgs_eval'],
                            tgt_batch_samples['gt_skl_eval'])
                        # NOTE(review): 'reconstructred_image' (sic) matches
                        # the key produced by the generator model — do not
                        # "fix" the spelling here alone.
                        concat_img_eval = np.concatenate(
                            (tensor_to_image(tgt_batch_samples['imgs_eval'][k]),
                             tensor_to_image(tgt_batch_samples['gt_skl_eval'][k]),
                             tensor_to_image(tgt_batch_samples['ref_imgs_eval'][k]),
                             tensor_to_image(generator_out_eval['reconstructred_image'][k])),
                            axis=2)  # concat along width
                        logger.add_image('Sample_{%d}_EVAL' % i,
                                         concat_img_eval, epoch)
                        genModel.train()
                k += 1
                k = k % 4
                concat_img = np.concatenate(
                    (tensor_to_image(batch_samples['imgs'][k]),
                     tensor_to_image(batch_samples['gt_skl'][k]),
                     tensor_to_image(batch_samples['ref_imgs'][k]),
                     tensor_to_image(generator_out['reconstructred_image'][k])),
                    axis=2)  # concat along width
                logger.add_image('Sample_{%d}' % i, concat_img, epoch)

        scheduler_generator.step()
Exemplo n.º 25
0
def train(config, generator, mask_generator, checkpoint, log_dir, dataset,
          device_ids):
    """Training loop for a generator plus mask generator.

    Args:
        config: parsed experiment config; reads ``train_params`` and
            ``visualizer_params``.
        generator, mask_generator: networks to train (mask generator may be
            frozen by setting ``lr_mask_generator`` to 0).
        checkpoint: optional checkpoint path to resume from.
        log_dir: directory the ``Logger`` writes logs/checkpoints into.
        dataset: training dataset, optionally repeated via DatasetRepeater.
        device_ids: GPU ids for ``DataParallelWithCallback``.
    """
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_mask_generator = torch.optim.Adam(
        mask_generator.parameters(),
        lr=train_params['lr_mask_generator'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        print('loading cpk')
        # When the mask generator's lr is 0 its optimizer state is not
        # restored (it is never stepped meaningfully).
        start_epoch = Logger.load_cpk(
            checkpoint, generator, mask_generator, optimizer_generator,
            None if train_params['lr_mask_generator'] == 0 else
            optimizer_mask_generator)
    else:
        start_epoch = 0

    print(start_epoch)
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    # last_epoch stays -1 (fresh schedule) when the mask generator is frozen.
    scheduler_mask_generator = MultiStepLR(
        optimizer_mask_generator,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_mask_generator'] != 0))

    # BUG FIX: the old condition ``'num_repeats' in train_params or
    # train_params['num_repeats'] != 1`` raised KeyError whenever the key was
    # absent (the right operand was still evaluated). Using .get with a
    # default of 1 is safe when unset and skips the no-op wrap when
    # num_repeats == 1.
    if train_params.get('num_repeats', 1) != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=6,
                            drop_last=True)

    generator_full = GeneratorFullModel(mask_generator, generator,
                                        train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for index, x in enumerate(dataloader):
                # Mask prediction is only enabled from the second epoch on.
                predict_mask = epoch >= 1
                losses_generator, generated = generator_full(x, predict_mask)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_mask_generator.step()
                optimizer_mask_generator.zero_grad()

                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_mask_generator.step()

            # `x` / `generated` hold the last batch of the epoch here.
            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'mask_generator': mask_generator,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_mask_generator': optimizer_mask_generator
                },
                inp=x,
                out=generated,
                save_w=True)
Exemplo n.º 26
0
class Framework:
    """Generic training framework for UCF-101 video classification.

    Builds the dataset, network and optimizer from a nested config dict and
    provides train/predict/evaluate loops. Subclasses must supply
    ``build_network`` and the per-batch hooks ``train_batch``,
    ``predict_batch`` and ``eval_batch`` — none of them are defined here.
    """

    # Batch dimension is axis 1: prepare_data stacks frames as
    # (time, batch, ...), so DataParallel must scatter on dim=1.
    batch_axis = 1

    def __init__(self, config):
        self.config = config
        self.build_dataset()
        self.build_network()  # provided by a subclass; sets self.model
        self.build_optimizer()

    def get_param(self, keys, default=None):
        """Look up a dotted path (e.g. 'network.sync_bn') in the config.

        Returns `default` as soon as any path component is missing.
        """
        node = self.config
        for key in keys.split('.'):
            if key in node:
                node = node[key]
            else:
                return default
        return node

    def cuda(self):
        """Move the model to GPU and wrap it for multi-GPU training.

        Uses the sync-BN-aware wrapper when 'network.sync_bn' is truthy,
        otherwise plain nn.DataParallel; both split on `batch_axis`.
        """
        self.model.cuda()
        if self.get_param('network.sync_bn', False):
            self.model = DataParallelWithCallback(self.model,
                                                  dim=self.batch_axis)
        else:
            self.model = nn.DataParallel(self.model, dim=self.batch_axis)

    def build_optimizer(self):
        """Create self.optimizer from config['optimizer'].

        NOTE(review): only 'SGD' and 'Adam' are handled; any other type
        leaves `optim_class` unbound and raises NameError. When 'lr' is a
        list of (lr, num_epochs) pairs, the first lr seeds the optimizer and
        set_learning_rate drives the schedule afterwards.
        """
        args = copy(self.config['optimizer'])
        if args['type'] == 'SGD':
            optim_class = torch.optim.SGD
        elif args['type'] == 'Adam':
            optim_class = torch.optim.Adam
        args.pop('type')
        if isinstance(args['lr'], list):
            args['lr'] = args['lr'][0][0]
        self.optimizer = optim_class(self.model.parameters(), **args)

    def set_learning_rate(self, epoch):
        """Apply the piecewise-constant lr schedule for `epoch`.

        config['optimizer']['lr'] may be a list of (lr, num_epochs) pairs;
        the pair whose cumulative epoch range contains `epoch` wins. A scalar
        lr means no scheduling.
        """
        lrs = self.config['optimizer']['lr']
        if isinstance(lrs, list):
            c = 0
            for lr, num_epochs in lrs:
                c += num_epochs
                if epoch <= c:
                    break
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            print('setting learning rate {}'.format(lr))

    def build_dataset(self):
        """Create train/val datasets and loaders for UCF-101 split 1."""
        length = self.config['data']['length']
        stride = 5  # fixed temporal stride between sampled frames
        self.train_data = dataset.build_ucf101_dataset(
            'traindev1',
            transforms=transforms.Compose([
                transforms.RandomResizedCrop((224, 224), (0.5, 1.0),
                                             ratio=(3 / 4, 4 / 3)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            length=length,
            stride=stride,
            config=self.config)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            sampler=sampler.RandomSampler(self.train_data,
                                          loop=self.config['data']['loop']),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        self.val_data, self.val_loader = self.build_test_dataset('val1')
        self.classes = self.train_data.classes

    def build_test_dataset(self, split):
        """Build an evaluation dataset/loader (center crop, no augmentation).

        Returns:
            (data, loader) for the given UCF-101 split name.
        """
        length = self.config['data']['length']
        stride = 5
        data = dataset.build_ucf101_dataset(split,
                                            transforms=transforms.Compose([
                                                transforms.CenterCrop(
                                                    (224, 224)),
                                                transforms.ToTensor()
                                            ]),
                                            length=length,
                                            stride=stride,
                                            config=self.config)
        loader = torch.utils.data.DataLoader(
            data,
            sampler=sampler.ValSampler(data, stride=length),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        return data, loader

    # Why Non_Blocking

    def prepare_data(self, data):
        """Move a raw loader batch to GPU.

        Returns ((frames, labels), vids): frames stacked on dim 0 (time-major,
        hence batch_axis == 1), labels on GPU, video ids as a numpy array.
        """
        frames = torch.stack(data[0], dim=0).cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        vids = data[2].numpy()
        return (frames, labels), vids

    def train_epoch(self, epoch):
        """Run one training epoch; returns averaged metrics (loss etc.)."""
        self.model.train()
        self.set_learning_rate(epoch)
        end = time.time()
        metrics = defaultdict(AverageMeter)
        for i, data in enumerate(self.train_loader):
            # measure data loading time
            metrics['data_time'].update(time.time() - end)

            args, _ = self.prepare_data(data)
            batch_size = args[0].size(1)  # dim 1 is the batch axis
            result = self.train_batch(*args)  # subclass hook; returns dict with 'loss'
            loss = result['loss']
            for k, v in result.items():
                metrics[k].update(v.item(), batch_size)

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            metrics['batch_time'].update(time.time() - end)
            end = time.time()

            if i % self.config['print_freq'] == 0:
                print('Epoch: [{0}][{1}/{2}]\t'.format(epoch, i,
                                                       len(self.train_loader)),
                      end='')

                for k, v in metrics.items():
                    print('{key} {val.avg:.3f}'.format(key=k, val=v), end='\t')
                print()
            # if i > 200:
            #     break

        # Timing meters are internal; only model metrics are returned.
        metrics.pop('batch_time')
        metrics.pop('data_time')
        return {k: v.avg for k, v in metrics.items()}

    def predict(self, dataloader):
        """Run inference over `dataloader`.

        Returns a dict of concatenated numpy arrays from predict_batch,
        plus 'indices' with the corresponding video ids.
        """
        self.model.eval()
        metrics = defaultdict(list)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                result = self.predict_batch(*args)  # subclass hook
                for k, v in result.items():
                    metrics[k].append(v.cpu().numpy())
                metrics['indices'].append(indices)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        for k in metrics:
            metrics[k] = np.concatenate(metrics[k], axis=0)
        return metrics

    def evaluate(self, dataloader):
        """Evaluate over `dataloader`; returns averaged scalar metrics.

        NOTE(review): batch size is read as args[0][0].size(0) here, unlike
        train_epoch's args[0].size(1) — presumably eval_batch receives the
        same frames tensor, so this indexes frame 0 and takes its batch dim;
        confirm against the subclass.
        """
        self.model.eval()
        metrics = defaultdict(AverageMeter)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                batch_size = args[0][0].size(0)
                result = self.eval_batch(*args)  # subclass hook
                for k, v in result.items():
                    metrics[k].update(v.item(), batch_size)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        return {k: v.avg for k, v in metrics.items()}
Exemplo n.º 27
0
class Trainer:
    """Hinge-loss GAN trainer with optional class-conditional labels.

    Builds a DCGAN-style generator/discriminator pair from `config`,
    optionally restores a checkpoint, and runs the alternating
    discriminator/generator update loop in ``train``.
    """

    def __init__(self, logger, checkpoint, device_ids, config):
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.dataset, n_classes = get_dataset(config['dataset'],
                                              config['dataset_params'])

        # Propagate the class count only when training conditionally.
        if self.config['with_labels']:
            self.config['generator_params']['n_classes'] = n_classes
            self.config['discriminator_params']['n_classes'] = n_classes
            self.config['n_classes'] = n_classes
        else:
            self.config['generator_params']['n_classes'] = None
            self.config['discriminator_params']['n_classes'] = None

        self.restore(checkpoint)

        print("Generator...")
        print(self.generator)

        print("Discriminator...")
        print(self.discriminator)

    def restore(self, checkpoint):
        """Build networks/optimizers/schedulers, loading `checkpoint` if given.

        The checkpoint is expected to be a dict mapping attribute names
        ('generator', 'optimizer_generator', ...) to state dicts, plus an
        'epoch' entry (the format written by ``save``).
        """
        self.epoch = 0

        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator,
                                                  device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0,
            eps=1e-8)

        self.discriminator = DCDiscriminator(
            **self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(
            self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'],
                   self.config['b2_discriminator']),
            weight_decay=0,
            eps=1e-8)

        if checkpoint is not None:
            data = torch.load(checkpoint)
            # BUG FIX: iterating the dict directly yields only keys, which
            # made `for key, value in data` fail (or silently misbehave);
            # .items() yields the (key, state_dict) pairs written by save().
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)

        # Linear lr decay down to 0 over num_epochs.
        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)

    def save(self):
        """Write a resumable checkpoint (epoch, models, optimizers) to log_dir."""
        state_dict = {
            'epoch': self.epoch,
            'generator': self.generator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_discriminator':
            self.optimizer_discriminator.state_dict()
        }

        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        """Alternating GAN training loop.

        The discriminator is updated every iteration with the hinge loss; the
        generator is updated every `num_discriminator_updates` iterations.
        A fixed-size noise (and optional label) buffer is reused and
        re-randomized in place each step.
        """
        loader = DataLoader(self.dataset,
                            batch_size=self.config['discriminator_bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=self.config['num_workers'])
        # One buffer big enough for either batch size; sliced per step.
        noise = torch.zeros((max(self.config['generator_bs'],
                                 self.config['discriminator_bs']),
                             self.config['generator_params']['dim_z'])).cuda()
        if self.config['with_labels']:
            labels_fake = torch.zeros(
                max(self.config['generator_bs'],
                    self.config['discriminator_bs'])).type(
                        torch.LongTensor).cuda()
        else:
            labels_fake = None

        y_fake = None
        # Keep track of current iteration for update generator
        current_iter = 0
        loss_dict = defaultdict(lambda: 0.0)

        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            for data in tqdm(loader):
                self.generator.train()
                current_iter += 1

                images, labels_real = data
                y_real = None if not self.config['with_labels'] else labels_real

                self.optimizer_generator.zero_grad()
                self.optimizer_discriminator.zero_grad()

                # ---- discriminator step (every iteration) ----
                z = noise.normal_()[:self.config['discriminator_bs']]
                if self.config['with_labels']:
                    y_fake = labels_fake.random_(
                        self.config['n_classes'])[:self.
                                                  config['discriminator_bs']]

                with torch.no_grad():
                    images_fake = self.generator(z, y_fake)

                logits_real = self.discriminator(images, y_real)
                logits_fake = self.discriminator(images_fake, y_fake)

                # Hinge loss for the discriminator.
                loss_fake = torch.relu(1 + logits_fake).mean()
                loss_real = torch.relu(1 - logits_real).mean()

                loss_dict['loss_fake'] += loss_fake.detach().cpu().numpy()
                loss_dict['loss_real'] += loss_real.detach().cpu().numpy()

                (loss_fake + loss_real).backward()
                self.optimizer_discriminator.step()

                # ---- generator step (every num_discriminator_updates) ----
                if current_iter % self.config['num_discriminator_updates'] == 0:
                    self.optimizer_discriminator.zero_grad()
                    self.optimizer_generator.zero_grad()

                    z = noise.normal_()[:self.config['generator_bs']]
                    if self.config['with_labels']:
                        y_fake = labels_fake.random_(
                            self.config['n_classes'])[:self.
                                                      config['generator_bs']]

                    images_fake = self.generator(z, y_fake)
                    logits_fake = self.discriminator(images_fake, y_fake)

                    adversarial_loss = -logits_fake.mean()
                    loss_dict['adversarial_loss'] += adversarial_loss.detach(
                    ).cpu().numpy()

                    adversarial_loss.backward()
                    self.optimizer_generator.step()

            # Per-epoch averages of the accumulated losses.
            save_dict = {
                key: value / current_iter
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generator.param_groups[0]['lr']

            loss_dict = defaultdict(lambda: 0.0)
            current_iter = 0

            # Sample images for visual monitoring.
            with torch.no_grad():
                noise = noise.normal_()
                if self.config['with_labels']:
                    labels_fake = labels_fake.random_(self.config['n_classes'])
                images = self.generator(noise, labels_fake)
                self.logger.save_images(self.epoch, images)

            # TODO: periodic quantitative evaluation (sample generation every
            # `eval_frequency` epochs) was sketched here and removed; restore
            # from history if needed.

            self.logger.log(self.epoch, save_dict)

            self.scheduler_generator.step()
            self.scheduler_discriminator.step()
            self.save()
# Top-level script section: echo parsed options, pin the primary GPU, build
# the action-recognition model and wrap it for multi-GPU execution.
print(opt)

# Make device_id[0] the current CUDA device (the DataParallel master).
torch.cuda.set_device(opt.device_id[0])

# ######################## Module #################################
print('Building model')
# NOTE(review): test_scheme=2 is hard-coded here — confirm it is the intended
# evaluation scheme rather than a leftover from experimentation.
model = actionModel(opt.class_num,
                    batch_norm=True,
                    dropout=opt.dropout,
                    q=opt.q,
                    image_size=opt.img_size,
                    syn_bn=opt.syn_bn,
                    test_scheme=2)
print(model)
# Synchronized BatchNorm needs the callback-aware DataParallel wrapper;
# otherwise the stock torch.nn.DataParallel suffices.
if opt.syn_bn:
    model = DataParallelWithCallback(model, device_ids=opt.device_id).cuda()
else:
    model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda()
print("Channels: " + str(model.module.channels))

# ########################Optimizer#########################
optimizer = torch.optim.SGD([{
    'params': model.module.RNN.parameters(),
    'lr': opt.LR[0]
}, {
    'params': model.module.ShortCut.parameters(),
    'lr': opt.LR[0]
}, {
    'params': model.module.classifier.parameters(),
    'lr': opt.LR[1]
}],
Exemplo n.º 29
0
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
    """Jointly train the generator, discriminator and keypoint detector.

    Args:
        config: parsed YAML config; its 'train_params' section supplies the
            learning rates, epoch counts, milestones and loss weights.
        generator / discriminator / kp_detector: freshly constructed modules;
            their weights are overwritten from `checkpoint` when one is given.
        checkpoint: path of a checkpoint to resume from, or None to start fresh.
        log_dir: directory that receives logs, visualizations and checkpoints.
        dataset: frames dataset; optionally repeated via 'num_repeats'.
        device_ids: GPU ids handed to DataParallelWithCallback.
    """
    train_params = config['train_params']

    # One Adam optimizer per sub-network.
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))

    if checkpoint is not None:
        # Resume: load model and optimizer states. The kp-detector optimizer
        # state is skipped when its lr is 0 (detector effectively frozen).
        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
                                      optimizer_generator, optimizer_discriminator,
                                      None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
    else:
        start_epoch = 0

    # Step-wise lr decay (x0.1) at the configured epoch milestones; last_epoch
    # keeps the schedule aligned when resuming mid-run.
    scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
                                        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    # BUGFIX: the original condition used `or`, which raised KeyError when
    # 'num_repeats' was absent and wrapped the dataset even for a repeat
    # count of 1. Repeat only when a count > 1 is actually configured.
    if train_params.get('num_repeats', 1) != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])

    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=2, drop_last=True)

    # The "full" models bundle the sub-networks with their loss computation so
    # the whole forward + loss pass can run under DataParallel.
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Forward pass: per-term losses plus the generated frames.
                losses_generator, generated = generator_full(x)

                # Average each loss term over the batch and sum them all.
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                # Update the generator and the keypoint detector.
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                # Adversarial step only when the GAN loss weight is enabled.
                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    # Score real frames against the generated ones.
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [val.mean() for val in losses_discriminator.values()]
                    loss = sum(loss_values)

                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                # Merge discriminator losses into the generator dict for logging
                # (plain dict.update).
                losses_generator.update(losses_discriminator)
                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

            # End of epoch: advance all lr schedulers, then checkpoint/log
            # using the last batch of the epoch for visualization.
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(epoch, {'generator': generator,
                                     'discriminator': discriminator,
                                     'kp_detector': kp_detector,
                                     'optimizer_generator': optimizer_generator,
                                     'optimizer_discriminator': optimizer_discriminator,
                                     'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
Exemplo n.º 30
0
def train(config, generator, region_predictor, bg_predictor, checkpoint,
          log_dir, dataset, device_ids):
    """Jointly train the generator, region predictor and background predictor.

    All three modules share one Adam optimizer; per-term losses come from
    ReconstructionModel and are summed. Logs, visualizations and checkpoints
    go to `log_dir`. `checkpoint` (path or None) enables resuming.
    """
    train_params = config['train_params']

    # A single optimizer over all three sub-networks.
    optimizer = torch.optim.Adam(list(generator.parameters()) +
                                 list(region_predictor.parameters()) +
                                 list(bg_predictor.parameters()),
                                 lr=train_params['lr'],
                                 betas=(0.5, 0.999))

    if checkpoint is not None:
        # Resume model + optimizer state from the checkpoint.
        start_epoch = Logger.load_cpk(checkpoint, generator, region_predictor,
                                      bg_predictor, None, optimizer, None)
    else:
        start_epoch = 0

    # Step-wise lr decay (x0.1) at the configured milestones; last_epoch
    # keeps the schedule aligned when resuming mid-run.
    scheduler = MultiStepLR(optimizer,
                            train_params['epoch_milestones'],
                            gamma=0.1,
                            last_epoch=start_epoch - 1)
    # BUGFIX: the original condition used `or`, which raised KeyError when
    # 'num_repeats' was absent and wrapped the dataset even for a repeat
    # count of 1. Repeat only when a count > 1 is actually configured.
    if train_params.get('num_repeats', 1) != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=train_params['dataloader_workers'],
                            drop_last=True)

    # Bundles the sub-networks with their loss computation so the whole
    # forward + loss pass can run under DataParallel.
    model = ReconstructionModel(region_predictor, bg_predictor, generator,
                                train_params)

    if torch.cuda.is_available():
        # Sync-BN needs the callback-aware DataParallel wrapper.
        if ('use_sync_bn' in train_params) and train_params['use_sync_bn']:
            model = DataParallelWithCallback(model, device_ids=device_ids)
        else:
            model = torch.nn.DataParallel(model, device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Forward: per-term losses plus the generated frames.
                losses, generated = model(x)
                loss_values = [val.mean() for val in losses.values()]
                loss = sum(loss_values)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Detach scalars for logging.
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses.items()
                }
                logger.log_iter(losses=losses)

            # End of epoch: advance the lr schedule, then checkpoint/log using
            # the last batch of the epoch for visualization.
            scheduler.step()
            logger.log_epoch(epoch, {
                'generator': generator,
                'bg_predictor': bg_predictor,
                'region_predictor': region_predictor,
                'optimizer_reconstruction': optimizer
            },
                             inp=x,
                             out=generated)
Exemplo n.º 31
0
                               batch_size=max(param_grid['batch_size']))
 criterion = nn.MSELoss()
 best_valid_RMSE = np.full(1, np.inf)
 for grid in ParameterGrid(param_grid):
     print(f"===> Hyper-parameters = {grid}:")
     # random.seed(seed)
     # numpy.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     importlib.reload(model)
     net = getattr(model,
                   model_name.upper())(input_channel=input_channel,
                                       input_size=input_size,
                                       output_size=output_size).to(device)
     if use_cuda:
         net = DataParallelWithCallback(net)
     print(f"===> Model:\n{list(net.modules())[0]}")
     print_param(net)
     if model_name == 'cnn0' or model_name == 'cnn1':
         optimizer = Adam(net.parameters(),
                          lr=grid['lr'],
                          l1=grid['l1'],
                          weight_decay=grid['l2'],
                          amsgrad=True)
     elif model_name == 'vgg19' or model_name == 'cnn2' or model_name == 'cnn3':
         optimizer = Adam([{
             'params':
             iter(param for name, param in net.named_parameters()
                  if 'channel_mask' in name),
             'l1':
             grid['l1_channel']
Exemplo n.º 32
0
class Model:
    """Trainer/evaluator pairing a SetNet encoder with a triplet loss.

    Provides batch collation for variable-length frame sequences
    (`collate_fn`), the optimization loop (`fit`), feature extraction
    (`transform`) and checkpoint persistence (`save`/`load`).
    """

    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):
        # NOTE: batch_size is unpacked as a (P, M) pair below — P labels per
        # batch, M sequences per label. hard_or_full_trip selects which
        # triplet-loss variant drives the backward pass in fit()
        # ('hard' or 'full').

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        # Encoder and loss are both wrapped for callback-aware data
        # parallelism, then moved to GPU.
        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        # Running metric buffers; flushed every 100 iterations in fit().
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        # 'random': subsample frame_num frames per sequence (training);
        # 'all': keep every frame (evaluation / transform).
        self.sample_type = 'all'

    def collate_fn(self, batch):
        """Collate items into [seqs, view, seq_type, label, batch_frames].

        In 'random' mode each sequence is subsampled to self.frame_num frames
        and stacked per feature. Otherwise all frames are kept: sequences are
        concatenated per GPU and zero-padded to the largest per-GPU frame
        total, and batch[4] records the per-GPU frame counts so the encoder
        can split them back.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            # Pick frames for one sequence: a random subset when training,
            # all frames otherwise. `feature` presumably is a pandas frame
            # indexed by frame id (uses .loc/.values) — TODO confirm.
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))

        if self.sample_type == 'random':
            # Fixed frame count per sequence: simple per-feature stacking.
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
        else:
            # Variable frame counts: split the batch across GPUs, then pad so
            # every GPU receives an equally-shaped array.
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            batch_frames = [[
                                len(frame_sets[i])
                                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                if i < batch_size
                                ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                # Pad the last GPU's frame-count table with zeros.
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            seqs = [[
                        np.concatenate([
                                           seqs[i][j]
                                           for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                           if i < batch_size
                                           ], 0) for _ in range(gpu_num)]
                    for j in range(feature_num)]
            seqs = [np.asarray([
                                   np.pad(seqs[j][_],
                                          ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                                          'constant',
                                          constant_values=0)
                                   for _ in range(gpu_num)])
                    for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch

    def fit(self):
        """Run the training loop until self.total_iter iterations are reached."""
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, label_prob = self.encoder(*seq, batch_frame)

            # Map labels to integer indices in the sorted training label set.
            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            # Compute the loss independently per feature bin (dim 1 of the
            # encoder output becomes the leading dim).
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            # NOTE(review): `loss` stays unbound if hard_or_full_trip is
            # neither 'hard' nor 'full' — callers must pass one of the two.
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            # Skip the backward pass once the loss has collapsed to ~0.
            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                # Print wall-clock time of the last 1000 iterations.
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                # Checkpoint, report averaged metrics, then reset the buffers.
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break

    def ts2var(self, x):
        # Wrap a tensor as a CUDA Variable (legacy pre-0.4 PyTorch API).
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        # Convert a numpy array into a CUDA Variable.
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        """Extract features for the test ('test') or train (any other flag) source.

        Returns (features, views, seq_types, labels); features is a
        (num_samples, num_bins * dim) numpy array.
        """
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            # Flatten (n, num_bin, dim) -> (n, num_bin * dim) per batch.
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        """Save encoder and optimizer state dicts under checkpoint/<model_name>/."""
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        """Restore encoder and optimizer state saved at `restore_iter`."""
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))