class Trainer:
    """Trainer for an (optionally class-conditional) DCGAN with hinge loss.

    NOTE(review): a second, unrelated ``Trainer`` class is defined later in
    this file and shadows this one at import time — consider renaming.
    """

    def __init__(self, logger, checkpoint, device_ids, config):
        self.config = config
        self.logger = logger
        self.device_ids = device_ids
        self.dataset, n_classes = get_dataset(config['dataset'],
                                              config['dataset_params'])
        if self.config['with_labels']:
            # Conditional GAN: both networks need the class count.
            self.config['generator_params']['n_classes'] = n_classes
            self.config['discriminator_params']['n_classes'] = n_classes
            self.config['n_classes'] = n_classes
        else:
            self.config['generator_params']['n_classes'] = None
            self.config['discriminator_params']['n_classes'] = None
        self.restore(checkpoint)
        print("Generator...")
        print(self.generator)
        print("Discriminator...")
        print(self.discriminator)

    def restore(self, checkpoint):
        """Build networks, optimizers and schedulers; optionally resume.

        Args:
            checkpoint: path to a checkpoint produced by ``save()``, or None.
        """
        self.epoch = 0
        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator,
                                                  device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0,
            eps=1e-8)
        self.discriminator = DCDiscriminator(
            **self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(
            self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'],
                   self.config['b2_discriminator']),
            weight_decay=0,
            eps=1e-8)
        if checkpoint is not None:
            data = torch.load(checkpoint)
            # BUG FIX: iterating a dict directly yields keys only, so the
            # original `for key, value in data:` failed at load time.
            # `data.items()` matches how the checkpoint is written in save().
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)
        # Linear lr decay to 0 over num_epochs; last_epoch resumes the decay
        # at the right point after a checkpoint restore.
        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)

    def save(self):
        """Snapshot epoch, networks and optimizers to <log_dir>/cpk.pth."""
        state_dict = {
            'epoch': self.epoch,
            'generator': self.generator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_discriminator':
                self.optimizer_discriminator.state_dict()
        }
        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        """Run the adversarial loop: hinge-loss D updates, periodic G updates."""
        # NOTE(review): shuffle=False on a training loader is unusual — confirm
        # this is intentional (e.g. dataset pre-shuffled).
        loader = DataLoader(self.dataset,
                            batch_size=self.config['discriminator_bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=self.config['num_workers'])
        # One reusable noise buffer sized for the larger of the two batch sizes;
        # it is refilled in place with normal_() and sliced per use.
        noise = torch.zeros((max(self.config['generator_bs'],
                                 self.config['discriminator_bs']),
                             self.config['generator_params']['dim_z'])).cuda()
        if self.config['with_labels']:
            labels_fake = torch.zeros(
                max(self.config['generator_bs'],
                    self.config['discriminator_bs'])).type(
                        torch.LongTensor).cuda()
        else:
            labels_fake = None
            y_fake = None
        # Counts discriminator iterations since the last epoch-level log.
        current_iter = 0
        loss_dict = defaultdict(lambda: 0.0)
        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            for data in tqdm(loader):
                self.generator.train()
                current_iter += 1
                images, labels_real = data
                y_real = None if not self.config['with_labels'] else labels_real

                # --- Discriminator update (hinge loss) ---
                self.optimizer_generator.zero_grad()
                self.optimizer_discriminator.zero_grad()
                z = noise.normal_()[:self.config['discriminator_bs']]
                if self.config['with_labels']:
                    y_fake = labels_fake.random_(
                        self.config['n_classes'])[
                            :self.config['discriminator_bs']]
                with torch.no_grad():
                    # No generator gradients needed for the D step.
                    images_fake = self.generator(z, y_fake)
                logits_real = self.discriminator(images, y_real)
                logits_fake = self.discriminator(images_fake, y_fake)
                loss_fake = torch.relu(1 + logits_fake).mean()
                loss_real = torch.relu(1 - logits_real).mean()
                loss_dict['loss_fake'] += loss_fake.detach().cpu().numpy()
                loss_dict['loss_real'] += loss_real.detach().cpu().numpy()
                (loss_fake + loss_real).backward()
                self.optimizer_discriminator.step()

                # --- Generator update every num_discriminator_updates steps ---
                if current_iter % self.config['num_discriminator_updates'] == 0:
                    self.optimizer_discriminator.zero_grad()
                    self.optimizer_generator.zero_grad()
                    z = noise.normal_()[:self.config['generator_bs']]
                    if self.config['with_labels']:
                        y_fake = labels_fake.random_(
                            self.config['n_classes'])[
                                :self.config['generator_bs']]
                    images_fake = self.generator(z, y_fake)
                    logits_fake = self.discriminator(images_fake, y_fake)
                    adversarial_loss = -logits_fake.mean()
                    loss_dict['adversarial_loss'] += adversarial_loss.detach(
                    ).cpu().numpy()
                    adversarial_loss.backward()
                    self.optimizer_generator.step()

            # Per-epoch averages, sample grid, log, lr step, checkpoint.
            save_dict = {
                key: value / current_iter for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generator.param_groups[0]['lr']
            loss_dict = defaultdict(lambda: 0.0)
            current_iter = 0
            with torch.no_grad():
                noise = noise.normal_()
                if self.config['with_labels']:
                    labels_fake = labels_fake.random_(self.config['n_classes'])
                images = self.generator(noise, labels_fake)
                self.logger.save_images(self.epoch, images)
            self.logger.log(self.epoch, save_dict)
            self.scheduler_generator.step()
            self.scheduler_discriminator.step()
            self.save()
class Model:
    """Gait-recognition trainer: a set-based encoder trained with triplet loss.

    NOTE(review): sequences appear to be pandas DataFrames indexed by frame id
    (``feature.loc[...]`` / ``feature.values`` in collate_fn) — confirm against
    the data source.
    """

    def __init__(self, hidden_dim, lr, hard_or_full_trip, margin, num_workers,
                 batch_size, restore_iter, total_iter, save_name,
                 train_pid_num, frame_num, model_name, train_source,
                 test_source, img_size=64):
        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        # batch_size is a (P, M) pair: P identities x M samples per identity.
        self.P, self.M = batch_size
        self.restore_iter = restore_iter
        self.total_iter = total_iter
        self.img_size = img_size
        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M,
                                        self.hard_or_full_trip,
                                        self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()
        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)
        # Running per-print-interval metric buffers (reset every 100 iters).
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01
        # 'random': fixed-length random frame sampling (training);
        # 'all': use every frame, padded per GPU (evaluation).
        self.sample_type = 'all'

    def collate_fn(self, batch):
        """Collate a list of (seq, frame_set, view, seq_type, label) samples.

        Returns [seqs, view, seq_type, label, batch_frames]; batch_frames is
        None in 'random' mode, otherwise a per-GPU frame-count matrix used to
        undo the zero padding inside the encoder.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            # Pick frames for one sample: a random multiset of frame_num
            # frames in training, or all frames in evaluation.
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))
        if self.sample_type == 'random':
            # Fixed length per sample -> stack directly, one array per feature.
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)])
                    for j in range(feature_num)]
        else:
            # Variable-length sequences: concatenate each GPU's share along the
            # frame axis, then zero-pad every GPU chunk to the same total.
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            batch_frames = [[
                len(frame_sets[i])
                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                if i < batch_size
            ] for _ in range(gpu_num)]
            # Pad the last GPU's count list with zeros so all rows are equal.
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max(
                [np.sum(batch_frames[_]) for _ in range(gpu_num)])
            seqs = [[
                np.concatenate([
                    seqs[i][j]
                    for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                    if i < batch_size
                ], 0) for _ in range(gpu_num)] for j in range(feature_num)]
            seqs = [np.asarray([
                np.pad(seqs[j][_],
                       ((0, max_sum_frame - seqs[j][_].shape[0]),
                        (0, 0), (0, 0)),
                       'constant',
                       constant_values=0)
                for _ in range(gpu_num)]) for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)
        batch[0] = seqs
        return batch

    def fit(self):
        """Train the encoder with triplet loss until total_iter iterations."""
        if self.restore_iter != 0:
            self.load(self.restore_iter)
        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)
        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()
        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()
            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            feature, label_prob = self.encoder(*seq, batch_frame)
            # Map string labels to dense indices in sorted label order.
            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()
            # (n, bins, dim) -> (bins, n, dim): triplet loss is computed per bin.
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(
                triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            # NOTE(review): any other hard_or_full_trip value leaves `loss`
            # unbound and raises NameError below — confirm inputs are validated.
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(
                hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(
                full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(
                full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            # Skip the backward pass when the loss is effectively zero.
            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(
                    np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(
                    np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(
                    np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                # Reset the per-interval metric buffers.
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break

    def ts2var(self, x):
        # Wrap a tensor as an autograd Variable on the GPU (legacy API).
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        # NumPy array -> CUDA Variable.
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        """Extract features for every sample in the train or test source.

        Args:
            flag: 'test' selects test_source, anything else train_source.
            batch_size: evaluation batch size (sequential sampling).

        Returns:
            (features, views, seq_types, labels) — features is (N, bins*dim).
        """
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))
            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label

        return np.concatenate(feature_list, 0), view_list, \
            seq_type_list, label_list

    def save(self):
        """Write encoder and optimizer state under checkpoint/<model_name>/."""
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        """Load encoder and optimizer state saved at iteration restore_iter."""
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
class Framework:
    """Base training framework for UCF-101 video classification.

    Subclasses are expected to provide ``build_network`` plus the per-batch
    hooks ``train_batch``, ``predict_batch`` and ``eval_batch`` (all referenced
    here but not defined in this class).
    """

    # Batch dimension index of the model input (frames are stacked with the
    # time axis first in prepare_data, so batch is axis 1).
    batch_axis = 1

    def __init__(self, config):
        self.config = config
        self.build_dataset()
        self.build_network()
        self.build_optimizer()

    def get_param(self, keys, default=None):
        """Look up a dotted path (e.g. 'network.sync_bn') in the config dict.

        Returns `default` as soon as any path component is missing.
        """
        node = self.config
        for key in keys.split('.'):
            if key in node:
                node = node[key]
            else:
                return default
        return node

    def cuda(self):
        """Move the model to GPU and wrap it for multi-GPU execution."""
        self.model.cuda()
        if self.get_param('network.sync_bn', False):
            self.model = DataParallelWithCallback(self.model,
                                                  dim=self.batch_axis)
        else:
            self.model = nn.DataParallel(self.model, dim=self.batch_axis)

    def build_optimizer(self):
        """Create self.optimizer from config['optimizer'] (type, lr, ...)."""
        args = copy(self.config['optimizer'])
        if args['type'] == 'SGD':
            optim_class = torch.optim.SGD
        elif args['type'] == 'Adam':
            optim_class = torch.optim.Adam
        else:
            # FIX: previously an unknown type fell through and raised a
            # confusing NameError on optim_class; fail fast with a clear error.
            raise ValueError(
                'Unsupported optimizer type: {}'.format(args['type']))
        args.pop('type')
        # 'lr' may be a piecewise schedule [[lr, num_epochs], ...];
        # initialize the optimizer with the first segment's lr.
        if isinstance(args['lr'], list):
            args['lr'] = args['lr'][0][0]
        self.optimizer = optim_class(self.model.parameters(), **args)

    def set_learning_rate(self, epoch):
        """Apply the piecewise-constant lr schedule for the given epoch.

        No-op when config['optimizer']['lr'] is a scalar.
        """
        lrs = self.config['optimizer']['lr']
        if isinstance(lrs, list):
            c = 0
            for lr, num_epochs in lrs:
                c += num_epochs
                if epoch <= c:
                    break
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            print('setting learning rate {}'.format(lr))

    def build_dataset(self):
        """Build augmented train loader plus the 'val1' evaluation loader."""
        length = self.config['data']['length']
        stride = 5
        self.train_data = dataset.build_ucf101_dataset(
            'traindev1',
            transforms=transforms.Compose([
                transforms.RandomResizedCrop((224, 224), (0.5, 1.0),
                                             ratio=(3 / 4, 4 / 3)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            length=length,
            stride=stride,
            config=self.config)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            sampler=sampler.RandomSampler(self.train_data,
                                          loop=self.config['data']['loop']),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        self.val_data, self.val_loader = self.build_test_dataset('val1')
        self.classes = self.train_data.classes

    def build_test_dataset(self, split):
        """Build a deterministic (center-crop) loader for the given split."""
        length = self.config['data']['length']
        stride = 5
        data = dataset.build_ucf101_dataset(
            split,
            transforms=transforms.Compose([
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor()
            ]),
            length=length,
            stride=stride,
            config=self.config)
        loader = torch.utils.data.DataLoader(
            data,
            sampler=sampler.ValSampler(data, stride=length),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        return data, loader

    def prepare_data(self, data):
        """Stack frames, move tensors to GPU, and split out video ids.

        non_blocking=True lets the host-to-device copy overlap with compute
        when the loader yields pinned memory; it is a no-op otherwise.
        """
        frames = torch.stack(data[0], dim=0).cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        vids = data[2].numpy()
        return (frames, labels), vids

    def train_epoch(self, epoch):
        """Train for one epoch; returns average loss/metric values."""
        self.model.train()
        self.set_learning_rate(epoch)
        end = time.time()
        metrics = defaultdict(AverageMeter)
        for i, data in enumerate(self.train_loader):
            # measure data loading time
            metrics['data_time'].update(time.time() - end)
            args, _ = self.prepare_data(data)
            # Frames are (time, batch, ...) — batch size lives on axis 1.
            batch_size = args[0].size(1)
            result = self.train_batch(*args)
            loss = result['loss']
            for k, v in result.items():
                metrics[k].update(v.item(), batch_size)
            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # measure elapsed time
            metrics['batch_time'].update(time.time() - end)
            end = time.time()
            if i % self.config['print_freq'] == 0:
                print('Epoch: [{0}][{1}/{2}]\t'.format(
                    epoch, i, len(self.train_loader)), end='')
                for k, v in metrics.items():
                    print('{key} {val.avg:.3f}'.format(key=k, val=v),
                          end='\t')
                print()
        # Timing meters are for console output only; don't report them.
        metrics.pop('batch_time')
        metrics.pop('data_time')
        return {k: v.avg for k, v in metrics.items()}

    def predict(self, dataloader):
        """Run inference; returns concatenated per-sample outputs + indices."""
        self.model.eval()
        metrics = defaultdict(list)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                result = self.predict_batch(*args)
                for k, v in result.items():
                    metrics[k].append(v.cpu().numpy())
                metrics['indices'].append(indices)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        for k in metrics:
            metrics[k] = np.concatenate(metrics[k], axis=0)
        return metrics

    def evaluate(self, dataloader):
        """Evaluate on a loader; returns averaged scalar metrics."""
        self.model.eval()
        metrics = defaultdict(AverageMeter)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                batch_size = args[0][0].size(0)
                result = self.eval_batch(*args)
                for k, v in result.items():
                    metrics[k].update(v.item(), batch_size)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        return {k: v.avg for k, v in metrics.items()}
print(f"===> Hyper-parameters = {grid}:") # random.seed(seed) # numpy.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) importlib.reload(model) net = getattr(model, model_name.upper())(input_channel=input_channel, input_size=input_size, output_size=output_size).to(device) if use_cuda: net = DataParallelWithCallback(net) print(f"===> Model:\n{list(net.modules())[0]}") print_param(net) if model_name == 'cnn0' or model_name == 'cnn1': optimizer = Adam(net.parameters(), lr=grid['lr'], l1=grid['l1'], weight_decay=grid['l2'], amsgrad=True) elif model_name == 'vgg19' or model_name == 'cnn2' or model_name == 'cnn3': optimizer = Adam([{ 'params': iter(param for name, param in net.named_parameters() if 'channel_mask' in name), 'l1': grid['l1_channel'] }, { 'params': iter(param for name, param in net.named_parameters() if 'spatial_mask' in name),
class Trainer:
    """CycleGAN-style trainer: B-direction always, A-direction (cycle) optional.

    ``self.BtoA`` is True iff cycle_loss_weight != 0; the A-side generator /
    discriminator only exist in that case.

    NOTE(review): this class shadows the earlier ``Trainer`` defined in this
    file — consider renaming one of them.
    """

    def __init__(self, logger, checkpoint, device_ids, config):
        self.BtoA = config['cycle_loss_weight'] != 0
        self.config = config
        self.logger = logger
        self.device_ids = device_ids
        self.restore(checkpoint)
        print("Generator...")
        print(self.generatorB)
        print("Discriminator...")
        print(self.discriminatorB)
        transform = list()
        transform.append(T.Resize(config['load_size']))
        transform.append(T.RandomCrop(config['crop_size']))
        transform.append(T.ToTensor())
        transform.append(T.Normalize(mean=(0.5, 0.5, 0.5),
                                     std=(0.5, 0.5, 0.5)))
        transform = T.Compose(transform)
        self.dataset = ABDataset(config['root_dir'], partition='train',
                                 transform=transform)

    def restore(self, checkpoint):
        """Build networks, optimizers, schedulers; optionally resume."""
        self.epoch = 0
        self.generatorB = Generator(**self.config['generator_params']).cuda()
        self.generatorB = DataParallelWithCallback(
            self.generatorB, device_ids=self.device_ids)
        self.optimizer_generatorB = torch.optim.Adam(
            self.generatorB.parameters(),
            lr=self.config['lr_generator'],
            betas=(0.5, 0.999))
        self.discriminatorB = Discriminator(
            **self.config['discriminator_params']).cuda()
        self.discriminatorB = DataParallelWithCallback(
            self.discriminatorB, device_ids=self.device_ids)
        self.optimizer_discriminatorB = torch.optim.Adam(
            self.discriminatorB.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(0.5, 0.999))
        if self.BtoA:
            self.generatorA = Generator(
                **self.config['generator_params']).cuda()
            self.generatorA = DataParallelWithCallback(
                self.generatorA, device_ids=self.device_ids)
            self.optimizer_generatorA = torch.optim.Adam(
                self.generatorA.parameters(),
                lr=self.config['lr_generator'],
                betas=(0.5, 0.999))
            self.discriminatorA = Discriminator(
                **self.config['discriminator_params']).cuda()
            self.discriminatorA = DataParallelWithCallback(
                self.discriminatorA, device_ids=self.device_ids)
            self.optimizer_discriminatorA = torch.optim.Adam(
                self.discriminatorA.parameters(),
                lr=self.config['lr_discriminator'],
                betas=(0.5, 0.999))
        if checkpoint is not None:
            data = torch.load(checkpoint)
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)
        # Constant lr for the first half of training, then linear decay to 0.
        lr_lambda = lambda epoch: min(
            1, 2 - 2 * epoch / self.config['num_epochs'])
        self.scheduler_generatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generatorB, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminatorB, lr_lambda,
            last_epoch=self.epoch - 1)
        if self.BtoA:
            self.scheduler_generatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_generatorA, lr_lambda,
                last_epoch=self.epoch - 1)
            self.scheduler_discriminatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_discriminatorA, lr_lambda,
                last_epoch=self.epoch - 1)

    def save(self):
        """Snapshot epoch plus all network/optimizer states to cpk.pth."""
        state_dict = {
            'epoch': self.epoch,
            'generatorB': self.generatorB.state_dict(),
            'optimizer_generatorB': self.optimizer_generatorB.state_dict(),
            'discriminatorB': self.discriminatorB.state_dict(),
            'optimizer_discriminatorB':
                self.optimizer_discriminatorB.state_dict()
        }
        if self.BtoA:
            state_dict.update({
                'generatorA': self.generatorA.state_dict(),
                'optimizer_generatorA':
                    self.optimizer_generatorA.state_dict(),
                'discriminatorA': self.discriminatorA.state_dict(),
                'optimizer_discriminatorA':
                    self.optimizer_discriminatorA.state_dict()
            })
        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        """Full training loop: identity, adversarial, equivariance and cycle
        losses for the generators, then discriminator updates, then per-epoch
        visualization, lr scheduling, logging and checkpointing."""
        np.random.seed(0)
        # NOTE(review): shuffle=False for a training loader is unusual — confirm.
        loader = DataLoader(self.dataset, batch_size=self.config['bs'],
                            shuffle=False, drop_last=True, num_workers=4)
        images_fixed = None  # first batch, kept for consistent visualizations
        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            loss_dict = defaultdict(lambda: 0.0)
            # BUG FIX: was initialized to 1, so per-epoch averages divided by
            # (num_batches + 1). Start at 0 and count each processed batch.
            iteration_count = 0
            for inp in tqdm(loader):
                images_A = inp['A'].cuda()
                images_B = inp['B'].cuda()
                if images_fixed is None:
                    images_fixed = {'A': images_A, 'B': images_B}
                    transform_fixed = Transform(
                        images_A.shape[0], **self.config['transform_params'])

                # --- Identity losses (gradients accumulated immediately) ---
                if self.config['identity_loss_weight'] != 0:
                    images_trg = self.generatorB(images_B, source=False)
                    identity_loss = l1(images_trg, images_B)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()
                    loss_dict['identity_loss_B'] += identity_loss.detach(
                    ).cpu().numpy()
                if self.config['identity_loss_weight'] != 0 and self.BtoA:
                    images_trg = self.generatorA(images_A, source=False)
                    identity_loss = l1(images_trg, images_A)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()
                    loss_dict['identity_loss_A'] += identity_loss.detach(
                    ).cpu().numpy()

                # --- Generator adversarial losses ---
                generator_loss = 0
                images_generatedB = self.generatorB(images_A, source=True)
                logits = self.discriminatorB(images_generatedB)
                adversarial_loss = gan_loss_generator(
                    logits, self.config['gan_loss_type'])
                adversarial_loss = self.config[
                    'adversarial_loss_weight'] * adversarial_loss
                generator_loss += adversarial_loss
                loss_dict['adversarial_loss_B'] += adversarial_loss.detach(
                ).cpu().numpy()
                if self.BtoA:
                    images_generatedA = self.generatorA(images_B, source=True)
                    logits = self.discriminatorA(images_generatedA)
                    adversarial_loss = gan_loss_generator(
                        logits, self.config['gan_loss_type'])
                    adversarial_loss = self.config[
                        'adversarial_loss_weight'] * adversarial_loss
                    generator_loss += adversarial_loss
                    loss_dict['adversarial_loss_A'] += adversarial_loss.detach(
                    ).cpu().numpy()

                # --- Generator equivariance losses ---
                if self.config['equivariance_loss_weight_generator'] != 0:
                    transform = Transform(images_generatedB.shape[0],
                                          **self.config['transform_params'])
                    images_A_transformed = transform.transform_frame(images_A)
                    loss = corr(
                        self.generatorB(images_A_transformed, source=True),
                        transform.transform_frame(images_generatedB))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_B'] += loss.detach(
                    ).cpu().numpy()
                if self.config[
                        'equivariance_loss_weight_generator'] != 0 and self.BtoA:
                    transform = Transform(images_generatedA.shape[0],
                                          **self.config['transform_params'])
                    images_B_transformed = transform.transform_frame(images_B)
                    # BUG FIX: the A-side equivariance term used generatorB on
                    # B-domain inputs; mirroring the B branch, it must test
                    # generatorA's equivariance (images_generatedA comes from
                    # generatorA(images_B)).
                    loss = corr(
                        self.generatorA(images_B_transformed, source=True),
                        transform.transform_frame(images_generatedA))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_A'] += loss.detach(
                    ).cpu().numpy()

                # --- Cycle-consistency losses (A->B->A and B->A->B) ---
                # (Redundant duplicate `self.BtoA` test removed; behavior
                # unchanged since BtoA already implies cycle_loss_weight != 0.)
                if self.BtoA and self.config['cycle_loss_weight'] != 0:
                    images_cycled = self.generatorA(images_generatedB,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_A).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_B'] += cycle_loss.detach().cpu(
                    ).numpy()
                    images_cycled = self.generatorB(images_generatedA,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_B).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_A'] += cycle_loss.detach().cpu(
                    ).numpy()

                generator_loss.backward()
                self.optimizer_generatorB.step()
                self.optimizer_generatorB.zero_grad()
                self.optimizer_discriminatorB.zero_grad()
                if self.BtoA:
                    self.optimizer_generatorA.step()
                    self.optimizer_generatorA.zero_grad()
                    self.optimizer_discriminatorA.zero_grad()

                # --- Discriminator B update ---
                logits_fake = self.discriminatorB(images_generatedB.detach())
                logits_real = self.discriminatorB(images_B)
                discriminator_loss = gan_loss_discriminator(
                    logits_real, logits_fake, self.config['gan_loss_type'])
                loss_dict['discriminator_loss_B'] += discriminator_loss.detach(
                ).cpu().numpy()
                if self.config['equivariance_loss_weight_discriminator'] != 0:
                    images_join = torch.cat(
                        [images_generatedB.detach(), images_B])
                    logits_join = torch.cat([logits_fake, logits_real])
                    transform = Transform(images_join.shape[0],
                                          **self.config['transform_params'])
                    images_transformed = transform.transform_frame(images_join)
                    loss = corr(self.discriminatorB(images_transformed),
                                transform.transform_frame(logits_join))
                    loss = self.config[
                        'equivariance_loss_weight_discriminator'] * loss
                    discriminator_loss += loss
                    loss_dict['equivariance_discriminator_B'] += loss.detach(
                    ).cpu().numpy()
                discriminator_loss.backward()
                self.optimizer_discriminatorB.step()
                self.optimizer_discriminatorB.zero_grad()
                self.optimizer_generatorB.zero_grad()

                # --- Discriminator A update (mirror of the B update) ---
                if self.BtoA:
                    logits_fake = self.discriminatorA(
                        images_generatedA.detach())
                    logits_real = self.discriminatorA(images_A)
                    discriminator_loss = gan_loss_discriminator(
                        logits_real, logits_fake, self.config['gan_loss_type'])
                    # BUG FIX: the A-side losses were accumulated under the
                    # '..._B' keys, corrupting both A and B logging curves.
                    loss_dict[
                        'discriminator_loss_A'] += discriminator_loss.detach(
                    ).cpu().numpy()
                    if self.config[
                            'equivariance_loss_weight_discriminator'] != 0:
                        images_join = torch.cat(
                            [images_generatedA.detach(), images_A])
                        logits_join = torch.cat([logits_fake, logits_real])
                        transform = Transform(
                            images_join.shape[0],
                            **self.config['transform_params'])
                        images_transformed = transform.transform_frame(
                            images_join)
                        loss = corr(self.discriminatorA(images_transformed),
                                    transform.transform_frame(logits_join))
                        loss = self.config[
                            'equivariance_loss_weight_discriminator'] * loss
                        discriminator_loss += loss
                        loss_dict[
                            'equivariance_discriminator_A'] += loss.detach(
                        ).cpu().numpy()
                    discriminator_loss.backward()
                    self.optimizer_discriminatorA.step()
                    self.optimizer_discriminatorA.zero_grad()
                    self.optimizer_generatorA.zero_grad()
                iteration_count += 1

            # --- Per-epoch visualization on the fixed first batch ---
            with torch.no_grad():
                if not self.BtoA:
                    self.generatorB.eval()
                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'],
                        self.generatorB(images_fixed['A'], source=True),
                        transformed,
                        self.generatorB(transformed, source=True))
                    self.generatorB.train()
                else:
                    self.generatorA.eval()
                    self.generatorB.eval()
                    images_generatedB = self.generatorB(images_fixed['A'],
                                                        source=True)
                    images_generatedA = self.generatorA(images_fixed['B'],
                                                        source=True)
                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'], images_generatedB,
                        transformed,
                        self.generatorB(transformed, source=True),
                        self.generatorA(images_generatedB, source=True),
                        images_fixed['B'], images_generatedA,
                        self.generatorB(images_generatedA, source=True))
                    self.generatorA.train()
                    self.generatorB.train()

            self.scheduler_generatorB.step()
            self.scheduler_discriminatorB.step()
            if self.BtoA:
                self.scheduler_generatorA.step()
                self.scheduler_discriminatorA.step()
            save_dict = {
                key: value / iteration_count
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generatorB.param_groups[0]['lr']
            self.logger.log(self.epoch, save_dict)
            self.save()