コード例 #1
0
class Trainer:
    """Trains a (optionally class-conditional) DCGAN with the hinge loss.

    Networks, Adam optimizers and linearly decaying LR schedulers are
    (re)built in ``restore``; ``train`` runs the alternating
    discriminator/generator update loop, logging and checkpointing once
    per epoch via ``logger`` and ``save``.
    """

    def __init__(self, logger, checkpoint, device_ids, config):
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.dataset, n_classes = get_dataset(config['dataset'],
                                              config['dataset_params'])

        # Propagate the dataset's label count into both network configs;
        # None disables class conditioning inside the models.
        if self.config['with_labels']:
            self.config['generator_params']['n_classes'] = n_classes
            self.config['discriminator_params']['n_classes'] = n_classes
            self.config['n_classes'] = n_classes
        else:
            self.config['generator_params']['n_classes'] = None
            self.config['discriminator_params']['n_classes'] = None

        self.restore(checkpoint)

        print("Generator...")
        print(self.generator)

        print("Discriminator...")
        print(self.discriminator)

    def restore(self, checkpoint):
        """Build networks/optimizers/schedulers, optionally resuming.

        Args:
            checkpoint: path to a file produced by ``save``, or None to
                start training from scratch.
        """
        self.epoch = 0

        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator,
                                                  device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0,
            eps=1e-8)

        self.discriminator = DCDiscriminator(
            **self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(
            self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'],
                   self.config['b2_discriminator']),
            weight_decay=0,
            eps=1e-8)

        if checkpoint is not None:
            data = torch.load(checkpoint)
            # BUG FIX: iterating a dict yields keys only, so the original
            # ``for key, value in data:`` raised "too many values to
            # unpack" on every restore; iterate key/value pairs instead.
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    # Keys match attribute names set above ('generator',
                    # 'optimizer_generator', ...), see ``save``.
                    self.__dict__[key].load_state_dict(value)

        # Linear decay of the LR to zero over num_epochs; last_epoch
        # resumes the schedule at the restored epoch.
        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)

    def save(self):
        """Write a resumable checkpoint ('cpk.pth') into the log dir."""
        state_dict = {
            'epoch': self.epoch,
            'generator': self.generator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_discriminator':
            self.optimizer_discriminator.state_dict()
        }

        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        """Run the adversarial training loop until num_epochs."""
        loader = DataLoader(self.dataset,
                            batch_size=self.config['discriminator_bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=self.config['num_workers'])
        # One reusable noise/label buffer sized for the larger of the two
        # batch sizes; slices of it are drawn per update below.
        noise = torch.zeros((max(self.config['generator_bs'],
                                 self.config['discriminator_bs']),
                             self.config['generator_params']['dim_z'])).cuda()
        if self.config['with_labels']:
            labels_fake = torch.zeros(
                max(self.config['generator_bs'],
                    self.config['discriminator_bs'])).type(
                        torch.LongTensor).cuda()
        else:
            labels_fake = None

        y_fake = None
        # Keep track of current iteration for update generator
        current_iter = 0
        loss_dict = defaultdict(lambda: 0.0)

        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            for data in tqdm(loader):
                self.generator.train()
                current_iter += 1

                images, labels_real = data
                y_real = None if not self.config['with_labels'] else labels_real

                self.optimizer_generator.zero_grad()
                self.optimizer_discriminator.zero_grad()

                # --- Discriminator step (every iteration) ---
                z = noise.normal_()[:self.config['discriminator_bs']]
                if self.config['with_labels']:
                    y_fake = labels_fake.random_(
                        self.config['n_classes'])[:self.
                                                  config['discriminator_bs']]

                # No generator gradients needed for the D update.
                with torch.no_grad():
                    images_fake = self.generator(z, y_fake)

                logits_real = self.discriminator(images, y_real)
                logits_fake = self.discriminator(images_fake, y_fake)

                # Hinge loss for the discriminator.
                loss_fake = torch.relu(1 + logits_fake).mean()
                loss_real = torch.relu(1 - logits_real).mean()

                loss_dict['loss_fake'] += loss_fake.detach().cpu().numpy()
                loss_dict['loss_real'] += loss_real.detach().cpu().numpy()

                (loss_fake + loss_real).backward()
                self.optimizer_discriminator.step()

                # --- Generator step (every num_discriminator_updates) ---
                if current_iter % self.config['num_discriminator_updates'] == 0:
                    self.optimizer_discriminator.zero_grad()
                    self.optimizer_generator.zero_grad()

                    z = noise.normal_()[:self.config['generator_bs']]
                    if self.config['with_labels']:
                        y_fake = labels_fake.random_(
                            self.config['n_classes'])[:self.
                                                      config['generator_bs']]

                    images_fake = self.generator(z, y_fake)
                    logits_fake = self.discriminator(images_fake, y_fake)

                    adversarial_loss = -logits_fake.mean()
                    loss_dict['adversarial_loss'] += adversarial_loss.detach(
                    ).cpu().numpy()

                    adversarial_loss.backward()
                    self.optimizer_generator.step()

            # Per-epoch averages (note: adversarial_loss is accumulated
            # less often but divided by the same iteration count).
            save_dict = {
                key: value / current_iter
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generator.param_groups[0]['lr']

            loss_dict = defaultdict(lambda: 0.0)
            current_iter = 0

            # Dump a grid of samples for visual inspection.
            with torch.no_grad():
                noise = noise.normal_()
                if self.config['with_labels']:
                    labels_fake = labels_fake.random_(self.config['n_classes'])
                images = self.generator(noise, labels_fake)
                self.logger.save_images(self.epoch, images)

            # if self.epoch % self.config['eval_frequency'] == 0 or self.epoch == self.config['num_epochs'] - 1:
            #     self.generator.eval()
            #
            #     if self.config['samples_evaluation'] != 0:
            #         generated = []
            #         with torch.no_grad():
            #             for i in range(self.config['samples_evaluation'] // noise.shape[0] + 1):
            #                 noise = noise.normal_()
            #                 if self.config['with_labels']:
            #                     labels_fake = labels_fake.random_(self.config['n_classes'])
            #
            #                 generated.append((127.5 * self.generator(noise, labels_fake) + 127.5).cpu().numpy())
            #
            #             generated = np.concatenate(generated)[:self.config['samples_evaluation']]
            #             self.logger.save_evaluation_images(self.epoch, generated)

            self.logger.log(self.epoch, save_dict)

            self.scheduler_generator.step()
            self.scheduler_discriminator.step()
            self.save()
コード例 #2
0
class Model:
    """Gait-recognition model wrapper (GaitSet-style training harness).

    Owns a ``SetNet`` encoder and a ``TripletLoss`` head (both wrapped in
    ``DataParallelWithCallback`` and moved to GPU), their Adam optimizer,
    and the fit/transform/save/load life cycle.
    """

    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        # batch_size is a (P, M) pair: P subjects x M sequences each.
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        # Running statistics, flushed every 100 iterations in fit().
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        # 'all' keeps every frame (test time); 'random' samples
        # frame_num frames per sequence (training time).
        self.sample_type = 'all'

    def collate_fn(self, batch):
        """Collate dataset items into a batch.

        Each item is (features, frame_set, view, seq_type, label).
        Returns ``[seqs, view, seq_type, label, batch_frames]`` where
        ``batch_frames`` is None for 'random' sampling and the per-GPU
        frame counts otherwise.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            # Pick frames of one sequence: a random sample (with
            # replacement) during training, or all frames otherwise.
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))

        if self.sample_type == 'random':
            # Fixed-length samples stack directly: one array per feature,
            # shaped (batch_size, frame_num, ...).
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
        else:
            # Variable-length samples: distribute sequences over GPUs,
            # concatenate each GPU's frames and zero-pad to a common
            # length so the arrays stack.
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            batch_frames = [[
                                len(frame_sets[i])
                                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                if i < batch_size
                                ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            seqs = [[
                        np.concatenate([
                                           seqs[i][j]
                                           for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                           if i < batch_size
                                           ], 0) for _ in range(gpu_num)]
                    for j in range(feature_num)]
            seqs = [np.asarray([
                                   np.pad(seqs[j][_],
                                          ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                                          'constant',
                                          constant_values=0)
                                   for _ in range(gpu_num)])
                    for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch

    def fit(self):
        """Train the encoder with the triplet loss until total_iter."""
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, label_prob = self.encoder(*seq, batch_frame)

            # Map string labels to indices in the sorted label set.
            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()
            else:
                # BUG FIX: for any other value `loss` was left unbound and
                # training crashed later with UnboundLocalError; fail fast
                # with an explicit error instead.
                raise ValueError(
                    "hard_or_full_trip must be 'hard' or 'full', got {!r}"
                    .format(self.hard_or_full_trip))

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            # Skip the update when the loss has effectively vanished.
            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break

    def ts2var(self, x):
        """Wrap a tensor in a Variable and move it to GPU."""
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        """Convert a numpy array to a CUDA Variable."""
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        """Extract features for the whole train or test source.

        Args:
            flag: 'test' selects test_source, anything else train_source.
            batch_size: loader batch size (sequences per step).

        Returns:
            (features, views, seq_types, labels) with features stacked
            into a single 2-D numpy array.
        """
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        """Checkpoint encoder and optimizer under checkpoint/<model_name>/."""
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        """Load the encoder/optimizer state saved at restore_iter."""
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
コード例 #3
0
class Framework:
    """Training/evaluation harness for a UCF-101 video classifier."""

    # Axis along which samples are batched (frames occupy axis 0).
    batch_axis = 1

    def __init__(self, config):
        self.config = config
        self.build_dataset()
        self.build_network()
        self.build_optimizer()

    def get_param(self, keys, default=None):
        """Look up a dotted path (e.g. 'network.sync_bn') in the config.

        Returns ``default`` as soon as any path component is missing.
        """
        node = self.config
        for key in keys.split('.'):
            if key in node:
                node = node[key]
            else:
                return default
        return node

    def cuda(self):
        """Move the model to GPU and wrap it for multi-GPU execution."""
        self.model.cuda()
        if self.get_param('network.sync_bn', False):
            self.model = DataParallelWithCallback(self.model,
                                                  dim=self.batch_axis)
        else:
            self.model = nn.DataParallel(self.model, dim=self.batch_axis)

    def build_optimizer(self):
        """Create the optimizer described by config['optimizer'].

        Raises:
            ValueError: if the configured type is not 'SGD' or 'Adam'.
        """
        args = copy(self.config['optimizer'])
        if args['type'] == 'SGD':
            optim_class = torch.optim.SGD
        elif args['type'] == 'Adam':
            optim_class = torch.optim.Adam
        else:
            # BUG FIX: optim_class was left unbound for unknown types,
            # which crashed later with UnboundLocalError.
            raise ValueError(
                'unsupported optimizer type: {}'.format(args['type']))
        args.pop('type')
        # An LR schedule is a list of (lr, num_epochs) pairs; start from
        # the first rate (set_learning_rate advances it per epoch).
        if isinstance(args['lr'], list):
            args['lr'] = args['lr'][0][0]
        self.optimizer = optim_class(self.model.parameters(), **args)

    def set_learning_rate(self, epoch):
        """Apply the piecewise-constant LR schedule for this epoch."""
        lrs = self.config['optimizer']['lr']
        if isinstance(lrs, list):
            c = 0
            # Find the segment whose cumulative epoch count covers
            # `epoch`; falls through to the last rate past the schedule.
            for lr, num_epochs in lrs:
                c += num_epochs
                if epoch <= c:
                    break
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            print('setting learning rate {}'.format(lr))

    def build_dataset(self):
        """Create the training dataset/loader and the validation split."""
        length = self.config['data']['length']
        stride = 5
        self.train_data = dataset.build_ucf101_dataset(
            'traindev1',
            transforms=transforms.Compose([
                transforms.RandomResizedCrop((224, 224), (0.5, 1.0),
                                             ratio=(3 / 4, 4 / 3)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            length=length,
            stride=stride,
            config=self.config)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            sampler=sampler.RandomSampler(self.train_data,
                                          loop=self.config['data']['loop']),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        self.val_data, self.val_loader = self.build_test_dataset('val1')
        self.classes = self.train_data.classes

    def build_test_dataset(self, split):
        """Build a deterministic (center-crop) dataset/loader for `split`."""
        length = self.config['data']['length']
        stride = 5
        data = dataset.build_ucf101_dataset(split,
                                            transforms=transforms.Compose([
                                                transforms.CenterCrop(
                                                    (224, 224)),
                                                transforms.ToTensor()
                                            ]),
                                            length=length,
                                            stride=stride,
                                            config=self.config)
        loader = torch.utils.data.DataLoader(
            data,
            sampler=sampler.ValSampler(data, stride=length),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        return data, loader

    # Why Non_Blocking

    def prepare_data(self, data):
        """Move a loader batch to GPU; returns ((frames, labels), vids)."""
        frames = torch.stack(data[0], dim=0).cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        vids = data[2].numpy()
        return (frames, labels), vids

    def train_epoch(self, epoch):
        """Train one epoch; returns the averaged training metrics."""
        self.model.train()
        self.set_learning_rate(epoch)
        end = time.time()
        metrics = defaultdict(AverageMeter)
        for i, data in enumerate(self.train_loader):
            # measure data loading time
            metrics['data_time'].update(time.time() - end)

            args, _ = self.prepare_data(data)
            batch_size = args[0].size(1)
            result = self.train_batch(*args)
            loss = result['loss']
            for k, v in result.items():
                metrics[k].update(v.item(), batch_size)

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            metrics['batch_time'].update(time.time() - end)
            end = time.time()

            if i % self.config['print_freq'] == 0:
                print('Epoch: [{0}][{1}/{2}]\t'.format(epoch, i,
                                                       len(self.train_loader)),
                      end='')

                for k, v in metrics.items():
                    print('{key} {val.avg:.3f}'.format(key=k, val=v), end='\t')
                print()
            # if i > 200:
            #     break

        # Timing meters are for console output only, not for logging.
        metrics.pop('batch_time')
        metrics.pop('data_time')
        return {k: v.avg for k, v in metrics.items()}

    def predict(self, dataloader):
        """Collect raw per-sample predictions (plus indices) as arrays."""
        self.model.eval()
        metrics = defaultdict(list)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                result = self.predict_batch(*args)
                for k, v in result.items():
                    metrics[k].append(v.cpu().numpy())
                metrics['indices'].append(indices)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        for k in metrics:
            metrics[k] = np.concatenate(metrics[k], axis=0)
        return metrics

    def evaluate(self, dataloader):
        """Run eval_batch over a loader; returns averaged metrics."""
        self.model.eval()
        metrics = defaultdict(AverageMeter)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                batch_size = args[0][0].size(0)
                result = self.eval_batch(*args)
                for k, v in result.items():
                    metrics[k].update(v.item(), batch_size)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        return {k: v.avg for k, v in metrics.items()}
コード例 #4
0
ファイル: main.py プロジェクト: wuziniu/Fixed_Kernel_CNN
 print(f"===> Hyper-parameters = {grid}:")
 # random.seed(seed)
 # numpy.random.seed(seed)
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
 importlib.reload(model)
 net = getattr(model,
               model_name.upper())(input_channel=input_channel,
                                   input_size=input_size,
                                   output_size=output_size).to(device)
 if use_cuda:
     net = DataParallelWithCallback(net)
 print(f"===> Model:\n{list(net.modules())[0]}")
 print_param(net)
 if model_name == 'cnn0' or model_name == 'cnn1':
     optimizer = Adam(net.parameters(),
                      lr=grid['lr'],
                      l1=grid['l1'],
                      weight_decay=grid['l2'],
                      amsgrad=True)
 elif model_name == 'vgg19' or model_name == 'cnn2' or model_name == 'cnn3':
     optimizer = Adam([{
         'params':
         iter(param for name, param in net.named_parameters()
              if 'channel_mask' in name),
         'l1':
         grid['l1_channel']
     }, {
         'params':
         iter(param for name, param in net.named_parameters()
              if 'spatial_mask' in name),
コード例 #5
0
ファイル: train.py プロジェクト: AliaksandrSiarohin/eq-inv
class Trainer:
    def __init__(self, logger, checkpoint, device_ids, config):
        """Set up models (via restore) and the training dataset.

        A zero ``cycle_loss_weight`` disables the B->A direction entirely;
        ``self.BtoA`` gates every A-side model and loss downstream.
        """
        self.BtoA = config['cycle_loss_weight'] != 0
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.restore(checkpoint)

        print("Generator...")
        print(self.generatorB)

        print("Discriminator...")
        print(self.discriminatorB)

        # Standard GAN preprocessing: resize, random crop, scale to [-1, 1].
        transform = T.Compose([
            T.Resize(config['load_size']),
            T.RandomCrop(config['crop_size']),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        self.dataset = ABDataset(config['root_dir'],
                                 partition='train',
                                 transform=transform)

    def restore(self, checkpoint):
        """Build (or rebuild) all networks, optimizers and LR schedulers.

        Fresh modules and Adam optimizers are always constructed first; when
        `checkpoint` is a path, their state dicts (plus the epoch counter)
        are then restored from it.  A-direction models exist only when
        `self.BtoA` is set.  Schedulers are created *after* the checkpoint is
        applied so their `last_epoch` reflects the restored epoch.
        """
        self.epoch = 0

        # B-direction (A -> B) generator/discriminator: always present.
        self.generatorB = Generator(**self.config['generator_params']).cuda()
        self.generatorB = DataParallelWithCallback(self.generatorB,
                                                   device_ids=self.device_ids)
        self.optimizer_generatorB = torch.optim.Adam(
            self.generatorB.parameters(),
            lr=self.config['lr_generator'],
            betas=(0.5, 0.999))

        self.discriminatorB = Discriminator(
            **self.config['discriminator_params']).cuda()
        self.discriminatorB = DataParallelWithCallback(
            self.discriminatorB, device_ids=self.device_ids)
        self.optimizer_discriminatorB = torch.optim.Adam(
            self.discriminatorB.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(0.5, 0.999))

        # A-direction (B -> A) models, only when cycle training is enabled.
        if self.BtoA:
            self.generatorA = Generator(
                **self.config['generator_params']).cuda()
            self.generatorA = DataParallelWithCallback(
                self.generatorA, device_ids=self.device_ids)
            self.optimizer_generatorA = torch.optim.Adam(
                self.generatorA.parameters(),
                lr=self.config['lr_generator'],
                betas=(0.5, 0.999))

            self.discriminatorA = Discriminator(
                **self.config['discriminator_params']).cuda()
            self.discriminatorA = DataParallelWithCallback(
                self.discriminatorA, device_ids=self.device_ids)
            self.optimizer_discriminatorA = torch.optim.Adam(
                self.discriminatorA.parameters(),
                lr=self.config['lr_discriminator'],
                betas=(0.5, 0.999))

        if checkpoint is not None:
            data = torch.load(checkpoint)
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    # Checkpoint keys match the attribute names assigned
                    # above, so each entry restores the matching object.
                    self.__dict__[key].load_state_dict(value)

        # LR multiplier: constant (1) for the first half of training, then
        # decaying linearly to 0 at num_epochs.  `last_epoch` resumes the
        # schedule at the restored epoch.
        lr_lambda = lambda epoch: min(
            1, 2 - 2 * epoch / self.config['num_epochs'])
        self.scheduler_generatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generatorB, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminatorB,
            lr_lambda,
            last_epoch=self.epoch - 1)

        if self.BtoA:
            self.scheduler_generatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_generatorA,
                lr_lambda,
                last_epoch=self.epoch - 1)
            self.scheduler_discriminatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_discriminatorA,
                lr_lambda,
                last_epoch=self.epoch - 1)

    def save(self):
        """Checkpoint the trainer state to ``<log_dir>/cpk.pth``.

        Serializes the epoch counter plus the state dicts of every model and
        optimizer (A-direction ones only when ``self.BtoA``).  Keys equal the
        attribute names, which is what ``restore`` relies on when loading.
        """
        names = [
            'generatorB', 'optimizer_generatorB',
            'discriminatorB', 'optimizer_discriminatorB',
        ]
        if self.BtoA:
            names += [
                'generatorA', 'optimizer_generatorA',
                'discriminatorA', 'optimizer_discriminatorA',
            ]

        state_dict = {'epoch': self.epoch}
        for name in names:
            state_dict[name] = getattr(self, name).state_dict()

        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        """Main optimization loop.

        For each epoch: per batch, (1) optional identity losses, (2) a joint
        generator update (adversarial + optional equivariance + optional
        cycle losses), (3) discriminator updates with optional equivariance
        regularization; then per-epoch visualizations on a fixed first batch,
        LR scheduler steps, metric logging and checkpointing.

        Fixes vs. the original:
        - the A-side generator equivariance term evaluated ``generatorB``
          instead of ``generatorA`` (copy-paste bug);
        - A-side discriminator losses were logged under ``..._B`` keys,
          corrupting both curves;
        - ``iteration_count`` started at 1, biasing every epoch average by
          one extra iteration;
        - the cycle-loss condition tested ``self.BtoA`` twice.
        """
        np.random.seed(0)
        # NOTE(review): shuffle=False is unusual for training; the first
        # batch is reused as a fixed visualization set, which may be the
        # reason — confirm this is intentional.
        loader = DataLoader(self.dataset,
                            batch_size=self.config['bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=4)
        images_fixed = None

        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            loss_dict = defaultdict(lambda: 0.0)
            # Fix: start at 0 so epoch averages divide by the true number of
            # iterations (the original started at 1).
            iteration_count = 0
            for inp in tqdm(loader):
                images_A = inp['A'].cuda()
                images_B = inp['B'].cuda()

                # Keep the first batch (and one fixed random transform) for
                # the per-epoch visualizations logged below.
                if images_fixed is None:
                    images_fixed = {'A': images_A, 'B': images_B}
                    transform_fixed = Transform(
                        images_A.shape[0], **self.config['transform_params'])

                # Identity loss: generatorB applied to a real B image should
                # reproduce it (backward is run immediately, separately from
                # the joint generator loss).
                if self.config['identity_loss_weight'] != 0:
                    images_trg = self.generatorB(images_B, source=False)
                    identity_loss = l1(images_trg, images_B)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()

                    loss_dict['identity_loss_B'] += identity_loss.detach().cpu(
                    ).numpy()

                if self.config['identity_loss_weight'] != 0 and self.BtoA:
                    images_trg = self.generatorA(images_A, source=False)
                    identity_loss = l1(images_trg, images_A)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()

                    loss_dict['identity_loss_A'] += identity_loss.detach().cpu(
                    ).numpy()

                # ---- Generator losses (accumulated, one backward) ----
                generator_loss = 0
                images_generatedB = self.generatorB(images_A, source=True)
                logits = self.discriminatorB(images_generatedB)
                adversarial_loss = gan_loss_generator(
                    logits, self.config['gan_loss_type'])
                adversarial_loss = self.config[
                    'adversarial_loss_weight'] * adversarial_loss
                generator_loss += adversarial_loss
                loss_dict['adversarial_loss_B'] += adversarial_loss.detach(
                ).cpu().numpy()

                if self.BtoA:
                    images_generatedA = self.generatorA(images_B, source=True)
                    logits = self.discriminatorA(images_generatedA)
                    adversarial_loss = gan_loss_generator(
                        logits, self.config['gan_loss_type'])
                    adversarial_loss = self.config[
                        'adversarial_loss_weight'] * adversarial_loss
                    generator_loss += adversarial_loss
                    loss_dict['adversarial_loss_A'] += adversarial_loss.detach(
                    ).cpu().numpy()

                # Equivariance: the generator's output on a transformed input
                # should correlate with the transformed output.
                if self.config['equivariance_loss_weight_generator'] != 0:
                    transform = Transform(images_generatedB.shape[0],
                                          **self.config['transform_params'])
                    images_A_transformed = transform.transform_frame(images_A)
                    loss = corr(
                        self.generatorB(images_A_transformed, source=True),
                        transform.transform_frame(images_generatedB))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_B'] += loss.detach().cpu(
                    ).numpy()

                if self.config[
                        'equivariance_loss_weight_generator'] != 0 and self.BtoA:
                    transform = Transform(images_generatedA.shape[0],
                                          **self.config['transform_params'])
                    images_B_transformed = transform.transform_frame(images_B)
                    # Fix: must evaluate generatorA (which produced
                    # images_generatedA from images_B); the original called
                    # generatorB here, comparing outputs of two different
                    # generators.
                    loss = corr(
                        self.generatorA(images_B_transformed, source=True),
                        transform.transform_frame(images_generatedA))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_A'] += loss.detach().cpu(
                    ).numpy()

                # Cycle consistency in both directions (fix: the original
                # redundantly tested `self.BtoA` twice in this condition).
                if self.BtoA and self.config['cycle_loss_weight'] != 0:
                    images_cycled = self.generatorA(images_generatedB,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_A).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_B'] += cycle_loss.detach().cpu(
                    ).numpy()

                    images_cycled = self.generatorB(images_generatedA,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_B).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_A'] += cycle_loss.detach().cpu(
                    ).numpy()

                generator_loss.backward()

                self.optimizer_generatorB.step()
                self.optimizer_generatorB.zero_grad()
                # Discriminators received gradients through the adversarial
                # terms; clear them before their own update.
                self.optimizer_discriminatorB.zero_grad()

                if self.BtoA:
                    self.optimizer_generatorA.step()
                    self.optimizer_generatorA.zero_grad()
                    self.optimizer_discriminatorA.zero_grad()

                # ---- Discriminator B update ----
                logits_fake = self.discriminatorB(images_generatedB.detach())
                logits_real = self.discriminatorB(images_B)
                discriminator_loss = gan_loss_discriminator(
                    logits_real, logits_fake, self.config['gan_loss_type'])
                loss_dict['discriminator_loss_B'] += discriminator_loss.detach(
                ).cpu().numpy()

                if self.config['equivariance_loss_weight_discriminator'] != 0:
                    images_join = torch.cat(
                        [images_generatedB.detach(), images_B])
                    logits_join = torch.cat([logits_fake, logits_real])

                    transform = Transform(images_join.shape[0],
                                          **self.config['transform_params'])
                    images_transformed = transform.transform_frame(images_join)
                    loss = corr(self.discriminatorB(images_transformed),
                                transform.transform_frame(logits_join))

                    loss = self.config[
                        'equivariance_loss_weight_discriminator'] * loss
                    discriminator_loss += loss
                    loss_dict['equivariance_discriminator_B'] += loss.detach(
                    ).cpu().numpy()

                discriminator_loss.backward()

                self.optimizer_discriminatorB.step()
                self.optimizer_discriminatorB.zero_grad()
                self.optimizer_generatorB.zero_grad()

                # ---- Discriminator A update (mirror of the B update) ----
                if self.BtoA:
                    logits_fake = self.discriminatorA(
                        images_generatedA.detach())
                    logits_real = self.discriminatorA(images_A)
                    discriminator_loss = gan_loss_discriminator(
                        logits_real, logits_fake, self.config['gan_loss_type'])
                    # Fix: log under the A key (the original reused the _B
                    # key, corrupting both logged curves).
                    loss_dict[
                        'discriminator_loss_A'] += discriminator_loss.detach(
                        ).cpu().numpy()

                    if self.config[
                            'equivariance_loss_weight_discriminator'] != 0:
                        images_join = torch.cat(
                            [images_generatedA.detach(), images_A])
                        logits_join = torch.cat([logits_fake, logits_real])

                        transform = Transform(
                            images_join.shape[0],
                            **self.config['transform_params'])
                        images_transformed = transform.transform_frame(
                            images_join)
                        loss = corr(self.discriminatorA(images_transformed),
                                    transform.transform_frame(logits_join))

                        loss = self.config[
                            'equivariance_loss_weight_discriminator'] * loss
                        discriminator_loss += loss
                        # Fix: likewise log under the A key.
                        loss_dict[
                            'equivariance_discriminator_A'] += loss.detach(
                            ).cpu().numpy()

                    discriminator_loss.backward()
                    self.optimizer_discriminatorA.step()
                    self.optimizer_discriminatorA.zero_grad()
                    self.optimizer_generatorA.zero_grad()

                iteration_count += 1

            # Per-epoch qualitative check on the fixed first batch.
            with torch.no_grad():
                if not self.BtoA:
                    self.generatorB.eval()
                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'],
                        self.generatorB(images_fixed['A'], source=True),
                        transformed, self.generatorB(transformed, source=True))
                    self.generatorB.train()
                else:
                    self.generatorA.eval()
                    self.generatorB.eval()

                    images_generatedB = self.generatorB(images_fixed['A'],
                                                        source=True)
                    images_generatedA = self.generatorA(images_fixed['B'],
                                                        source=True)

                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'], images_generatedB,
                        transformed, self.generatorB(transformed, source=True),
                        self.generatorA(images_generatedB, source=True),
                        images_fixed['B'], images_generatedA,
                        self.generatorB(images_generatedA, source=True))

                    self.generatorA.train()
                    self.generatorB.train()

            self.scheduler_generatorB.step()
            self.scheduler_discriminatorB.step()
            if self.BtoA:
                self.scheduler_generatorA.step()
                self.scheduler_discriminatorA.step()

            # Average the accumulated losses over the iterations actually
            # run (guard against an empty loader after drop_last).
            save_dict = {
                key: value / max(iteration_count, 1)
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generatorB.param_groups[0]['lr']

            self.logger.log(self.epoch, save_dict)
            self.save()