class Trainer:
    def __init__(self, logger, checkpoint, device_ids, config):
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.dataset, n_classes = get_dataset(config['dataset'], config['dataset_params'])
        if self.config['with_labels']:
            self.config['generator_params']['n_classes'] = n_classes
            self.config['discriminator_params']['n_classes'] = n_classes
            self.config['n_classes'] = n_classes
        else:
            self.config['generator_params']['n_classes'] = None
            self.config['discriminator_params']['n_classes'] = None

        self.restore(checkpoint)

        print("Generator...")
        print(self.generator)
        print("Discriminator...")
        print(self.discriminator)

    def restore(self, checkpoint):
        self.epoch = 0
        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator, device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0, eps=1e-8)

        self.discriminator = DCDiscriminator(**self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'], self.config['b2_discriminator']),
            weight_decay=0, eps=1e-8)

        if checkpoint is not None:
            data = torch.load(checkpoint)
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)

        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)

    def save(self):
        state_dict = {
            'epoch': self.epoch,
            'generator': self.generator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_discriminator': self.optimizer_discriminator.state_dict()
        }
        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        loader = DataLoader(self.dataset, batch_size=self.config['discriminator_bs'],
                            shuffle=False, drop_last=True,
                            num_workers=self.config['num_workers'])

        noise = torch.zeros((max(self.config['generator_bs'], self.config['discriminator_bs']),
                             self.config['generator_params']['dim_z'])).cuda()
        if self.config['with_labels']:
            labels_fake = torch.zeros(
                max(self.config['generator_bs'], self.config['discriminator_bs'])
            ).type(torch.LongTensor).cuda()
        else:
            labels_fake = None
            y_fake = None

        # Track the current iteration so the generator is updated every
        # num_discriminator_updates discriminator steps.
        current_iter = 0
        loss_dict = defaultdict(lambda: 0.0)
        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            for data in tqdm(loader):
                self.generator.train()
                current_iter += 1
                images, labels_real = data
                y_real = labels_real if self.config['with_labels'] else None

                self.optimizer_generator.zero_grad()
                self.optimizer_discriminator.zero_grad()

                # Discriminator update (hinge loss)
                z = noise.normal_()[:self.config['discriminator_bs']]
                if self.config['with_labels']:
                    y_fake = labels_fake.random_(self.config['n_classes'])[:self.config['discriminator_bs']]

                with torch.no_grad():
                    images_fake = self.generator(z, y_fake)

                logits_real = self.discriminator(images, y_real)
                logits_fake = self.discriminator(images_fake, y_fake)

                loss_fake = torch.relu(1 + logits_fake).mean()
                loss_real = torch.relu(1 - logits_real).mean()

                loss_dict['loss_fake'] += loss_fake.detach().cpu().numpy()
                loss_dict['loss_real'] += loss_real.detach().cpu().numpy()

                (loss_fake + loss_real).backward()
                self.optimizer_discriminator.step()

                # Generator update
                if current_iter % self.config['num_discriminator_updates'] == 0:
                    self.optimizer_discriminator.zero_grad()
                    self.optimizer_generator.zero_grad()

                    z = noise.normal_()[:self.config['generator_bs']]
                    if self.config['with_labels']:
                        y_fake = labels_fake.random_(self.config['n_classes'])[:self.config['generator_bs']]

                    images_fake = self.generator(z, y_fake)
                    logits_fake = self.discriminator(images_fake, y_fake)
                    adversarial_loss = -logits_fake.mean()
                    loss_dict['adversarial_loss'] += adversarial_loss.detach().cpu().numpy()

                    adversarial_loss.backward()
                    self.optimizer_generator.step()

            save_dict = {key: value / current_iter for key, value in loss_dict.items()}
            save_dict['lr'] = self.optimizer_generator.param_groups[0]['lr']
            loss_dict = defaultdict(lambda: 0.0)
            current_iter = 0

            with torch.no_grad():
                noise = noise.normal_()
                if self.config['with_labels']:
                    labels_fake = labels_fake.random_(self.config['n_classes'])
                images = self.generator(noise, labels_fake)
                self.logger.save_images(self.epoch, images)

            # if self.epoch % self.config['eval_frequency'] == 0 or self.epoch == self.config['num_epochs'] - 1:
            #     self.generator.eval()
            #
            #     if self.config['samples_evaluation'] != 0:
            #         generated = []
            #         with torch.no_grad():
            #             for i in range(self.config['samples_evaluation'] // noise.shape[0] + 1):
            #                 noise = noise.normal_()
            #                 if self.config['with_labels']:
            #                     labels_fake = labels_fake.random_(self.config['n_classes'])
            #
            #                 generated.append((127.5 * self.generator(noise, labels_fake) + 127.5).cpu().numpy())
            #
            #         generated = np.concatenate(generated)[:self.config['samples_evaluation']]
            #         self.logger.save_evaluation_images(self.epoch, generated)

            self.logger.log(self.epoch, save_dict)
            self.scheduler_generator.step()
            self.scheduler_discriminator.step()
            self.save()
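
# The Trainer above uses the hinge formulation of the GAN objective: the
# discriminator minimises relu(1 - D(real)) + relu(1 + D(fake)) and the
# generator minimises -D(fake). A minimal self-contained sketch of those two
# loss terms (plain PyTorch, independent of the Trainer's config and model
# classes, which are assumptions here):
import torch


def hinge_d_loss(logits_real, logits_fake):
    # discriminator pushes real logits above +1 and fake logits below -1
    return torch.relu(1 - logits_real).mean() + torch.relu(1 + logits_fake).mean()


def hinge_g_loss(logits_fake):
    # generator pushes the discriminator's score on fakes upwards
    return -logits_fake.mean()


# Example usage:
#   real, fake = torch.randn(8, 1), torch.randn(8, 1)
#   d_loss = hinge_d_loss(real, fake)
#   g_loss = hinge_g_loss(fake)
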
def debug_generator(generator, kp_to_skl_gt, loader, train_params, logger, device_ids, tgt_batch=None):
    log_params = train_params['log_params']
    genModel = ConditionalGenerator2D(generator, train_params)
    genModel = DataParallelWithCallback(genModel, device_ids=device_ids)

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=train_params['betas'])
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1, last_epoch=-1)

    k = 0
    train_views = [0, 1, 3]
    eval_views = [2]

    if tgt_batch is not None:
        tgt_batch_samples = split_data(tgt_batch, train_views=train_views, eval_views=eval_views)
        with torch.no_grad():
            tgt_batch_samples['gt_skl'] = kp_to_skl_gt(tgt_batch_samples['kps'].to('cuda')).unsqueeze(1)
            tgt_batch_samples['gt_skl_eval'] = kp_to_skl_gt(tgt_batch_samples['kps_eval'].to('cuda')).unsqueeze(1)

    for epoch in range(train_params['num_epochs']):
        for i, batch in enumerate(tqdm(loader)):
            img, annots, ref_img = batch
            batch_samples = split_data((img, annots, ref_img),
                                       train_views=train_views, eval_views=eval_views)
            # imgs = flatten_views(imgs)
            # ref_imgs = flatten_views(ref_imgs)
            # ref_imgs = torch.rand(*imgs.shape)
            with torch.no_grad():
                batch_samples['gt_skl'] = kp_to_skl_gt(batch_samples['kps'].to('cuda')).unsqueeze(1)
                # batch_samples['gt_skl_eval'] = kp_to_skl_gt(batch_samples['kps_eval'].to('cuda')).unsqueeze(1)
                # gt_skl = (kp_to_skl_gt(flatten_views(annots / (ref_img.shape[3] - 1)).to('cuda'))).unsqueeze(1)
                # gt_skl = torch.rand(imgs.shape[0], 1, *imgs.shape[2:])

            # generator_out = genModel(imgs, ref_imgs, gt_skl)
            generator_out = genModel(batch_samples['imgs'],
                                     batch_samples['ref_imgs'],
                                     batch_samples['gt_skl'])

            # #### Generator update
            # loss_generator = generator_out['loss']
            loss_generator = generator_out['perceptual_loss']
            loss_generator = [x.mean() for x in loss_generator]
            loss_gen = sum(loss_generator)
            loss_gen.backward(retain_graph=not train_params['detach_kp_discriminator'])
            optimizer_generator.step()
            optimizer_generator.zero_grad()

            # #### Log
            logger.add_scalar("Generator Loss", loss_gen.item(), epoch * len(loader) + i + 1)
            if i in log_params['log_imgs']:
                if tgt_batch is not None:
                    with torch.no_grad():
                        genModel.eval()
                        generator_out_eval = genModel(tgt_batch_samples['imgs_eval'],
                                                      tgt_batch_samples['ref_imgs_eval'],
                                                      tgt_batch_samples['gt_skl_eval'])
                        # generator_out_eval = genModel(batch_samples['imgs_eval'],
                        #                               batch_samples['ref_imgs_eval'],
                        #                               batch_samples['gt_skl_eval'])
                        concat_img_eval = np.concatenate(
                            (tensor_to_image(tgt_batch_samples['imgs_eval'][k]),
                             tensor_to_image(tgt_batch_samples['gt_skl_eval'][k]),
                             tensor_to_image(tgt_batch_samples['ref_imgs_eval'][k]),
                             tensor_to_image(generator_out_eval['reconstructred_image'][k])),
                            axis=2)  # concat along width
                        logger.add_image('Sample_{%d}_EVAL' % i, concat_img_eval, epoch)
                        genModel.train()
                k += 1
                k = k % 4
                concat_img = np.concatenate(
                    (tensor_to_image(batch_samples['imgs'][k]),
                     tensor_to_image(batch_samples['gt_skl'][k]),
                     tensor_to_image(batch_samples['ref_imgs'][k]),
                     tensor_to_image(generator_out['reconstructred_image'][k])),
                    axis=2)  # concat along width
                logger.add_image('Sample_{%d}' % i, concat_img, epoch)

        scheduler_generator.step()
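
# For the image logging above, tensor_to_image is assumed to map a CHW tensor
# in [0, 1] to a CHW numpy array so that several panels (input, skeleton,
# reference, reconstruction) can be concatenated along the width axis (axis=2)
# into one strip. A minimal sketch of such a helper under those assumptions
# (the repo's real tensor_to_image may differ):
import numpy as np
import torch


def tensor_to_image_sketch(t):
    """Convert a CHW tensor in [0, 1] to a CHW float32 numpy array, repeating
    single-channel inputs to 3 channels so all panels concatenate cleanly."""
    img = t.detach().cpu().clamp(0, 1)
    if img.shape[0] == 1:
        img = img.repeat(3, 1, 1)
    return img.numpy().astype(np.float32)


# Example: a 3-panel strip (input | skeleton | reconstruction).
#   panels = [torch.rand(3, 64, 64), torch.rand(1, 64, 64), torch.rand(3, 64, 64)]
#   strip = np.concatenate([tensor_to_image_sketch(p) for p in panels], axis=2)
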
class Model:
    def __init__(self, hidden_dim, lr, hard_or_full_trip, margin, num_workers,
                 batch_size, restore_iter, total_iter, save_name, train_pid_num,
                 frame_num, model_name, train_source, test_source, img_size=64):

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter
        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        self.sample_type = 'all'

    def collate_fn(self, batch):
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))

        if self.sample_type == 'random':
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
        else:
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            batch_frames = [[
                len(frame_sets[i])
                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                if i < batch_size
            ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            seqs = [[
                np.concatenate([
                    seqs[i][j]
                    for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                    if i < batch_size
                ], 0) for _ in range(gpu_num)] for j in range(feature_num)]
            seqs = [np.asarray([
                np.pad(seqs[j][_],
                       ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                       'constant', constant_values=0)
                for _ in range(gpu_num)]) for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch

    def fit(self):
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, label_prob = self.encoder(*seq, batch_frame)

            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break

    def ts2var(self, x):
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
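
# In collate_fn's 'all' branch above, variable-length silhouette sequences are
# grouped per GPU, concatenated along the frame axis, and zero-padded so every
# GPU receives an array with the same total number of frames; batch_frames then
# records the true per-sample lengths. A stripped-down numpy sketch of that
# padding idea (toy shapes, not the repo's actual data layout):
import math
import numpy as np


def pad_per_gpu(seqs, gpu_num):
    """seqs: list of (frames_i, H, W) arrays. Returns one stacked, padded array
    per GPU plus the per-GPU lists of true frame counts."""
    per_gpu = math.ceil(len(seqs) / gpu_num)
    groups = [seqs[g * per_gpu:(g + 1) * per_gpu] for g in range(gpu_num)]
    frame_counts = [[s.shape[0] for s in g] for g in groups]
    max_total = max(sum(c) for c in frame_counts)
    padded = []
    for g in groups:
        cat = np.concatenate(g, axis=0)
        pad = max_total - cat.shape[0]
        padded.append(np.pad(cat, ((0, pad), (0, 0), (0, 0)), 'constant'))
    return np.stack(padded), frame_counts


# Example:
#   batch, counts = pad_per_gpu([np.ones((n, 4, 4)) for n in (3, 5, 2, 4)], gpu_num=2)
#   batch.shape -> (2, 8, 4, 4); counts -> [[3, 5], [2, 4]]
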
class Framework:
    batch_axis = 1

    def __init__(self, config):
        self.config = config
        self.build_dataset()
        self.build_network()
        self.build_optimizer()

    def get_param(self, keys, default=None):
        node = self.config
        for key in keys.split('.'):
            if key in node:
                node = node[key]
            else:
                return default
        return node

    def cuda(self):
        self.model.cuda()
        if self.get_param('network.sync_bn', False):
            self.model = DataParallelWithCallback(self.model, dim=self.batch_axis)
        else:
            self.model = nn.DataParallel(self.model, dim=self.batch_axis)

    def build_optimizer(self):
        args = copy(self.config['optimizer'])
        if args['type'] == 'SGD':
            optim_class = torch.optim.SGD
        elif args['type'] == 'Adam':
            optim_class = torch.optim.Adam
        args.pop('type')
        if isinstance(args['lr'], list):
            args['lr'] = args['lr'][0][0]
        self.optimizer = optim_class(self.model.parameters(), **args)

    def set_learning_rate(self, epoch):
        lrs = self.config['optimizer']['lr']
        if isinstance(lrs, list):
            c = 0
            for lr, num_epochs in lrs:
                c += num_epochs
                if epoch <= c:
                    break
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            print('setting learning rate {}'.format(lr))

    def build_dataset(self):
        length = self.config['data']['length']
        stride = 5
        self.train_data = dataset.build_ucf101_dataset(
            'traindev1',
            transforms=transforms.Compose([
                transforms.RandomResizedCrop((224, 224), (0.5, 1.0), ratio=(3 / 4, 4 / 3)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            length=length, stride=stride, config=self.config)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            sampler=sampler.RandomSampler(self.train_data, loop=self.config['data']['loop']),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])

        self.val_data, self.val_loader = self.build_test_dataset('val1')
        self.classes = self.train_data.classes

    def build_test_dataset(self, split):
        length = self.config['data']['length']
        stride = 5
        data = dataset.build_ucf101_dataset(
            split,
            transforms=transforms.Compose([
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor()
            ]),
            length=length, stride=stride, config=self.config)
        loader = torch.utils.data.DataLoader(
            data,
            sampler=sampler.ValSampler(data, stride=length),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        return data, loader

    def prepare_data(self, data):
        # non_blocking=True lets the host-to-device copy overlap with compute
        # when the source tensors come from pinned memory.
        frames = torch.stack(data[0], dim=0).cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        vids = data[2].numpy()
        return (frames, labels), vids

    def train_epoch(self, epoch):
        self.model.train()
        self.set_learning_rate(epoch)

        end = time.time()
        metrics = defaultdict(AverageMeter)
        for i, data in enumerate(self.train_loader):
            # measure data loading time
            metrics['data_time'].update(time.time() - end)

            args, _ = self.prepare_data(data)
            batch_size = args[0].size(1)
            result = self.train_batch(*args)
            loss = result['loss']
            for k, v in result.items():
                metrics[k].update(v.item(), batch_size)

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            metrics['batch_time'].update(time.time() - end)
            end = time.time()

            if i % self.config['print_freq'] == 0:
                print('Epoch: [{0}][{1}/{2}]\t'.format(epoch, i, len(self.train_loader)), end='')
                for k, v in metrics.items():
                    print('{key} {val.avg:.3f}'.format(key=k, val=v), end='\t')
                print()
            # if i > 200:
            #     break

        metrics.pop('batch_time')
        metrics.pop('data_time')
        return {k: v.avg for k, v in metrics.items()}

    def predict(self, dataloader):
        self.model.eval()
        metrics = defaultdict(list)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                result = self.predict_batch(*args)
                for k, v in result.items():
                    metrics[k].append(v.cpu().numpy())
                metrics['indices'].append(indices)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        for k in metrics:
            metrics[k] = np.concatenate(metrics[k], axis=0)
        return metrics

    def evaluate(self, dataloader):
        self.model.eval()
        metrics = defaultdict(AverageMeter)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                batch_size = args[0][0].size(0)
                result = self.eval_batch(*args)
                for k, v in result.items():
                    metrics[k].update(v.item(), batch_size)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        return {k: v.avg for k, v in metrics.items()}
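
# train_epoch/evaluate above keep per-metric running averages in an
# AverageMeter, which is not defined in this excerpt. It is assumed to be the
# usual running-average helper (as popularised by the PyTorch ImageNet
# example); a minimal sketch compatible with the calls made above:
class AverageMeter:
    """Tracks the current value, running sum, count and average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
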
class Trainer:
    def __init__(self, logger, checkpoint, device_ids, config):
        self.BtoA = config['cycle_loss_weight'] != 0
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.restore(checkpoint)

        print("Generator...")
        print(self.generatorB)
        print("Discriminator...")
        print(self.discriminatorB)

        transform = list()
        transform.append(T.Resize(config['load_size']))
        transform.append(T.RandomCrop(config['crop_size']))
        transform.append(T.ToTensor())
        transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        transform = T.Compose(transform)

        self.dataset = ABDataset(config['root_dir'], partition='train', transform=transform)

    def restore(self, checkpoint):
        self.epoch = 0

        self.generatorB = Generator(**self.config['generator_params']).cuda()
        self.generatorB = DataParallelWithCallback(self.generatorB, device_ids=self.device_ids)
        self.optimizer_generatorB = torch.optim.Adam(
            self.generatorB.parameters(), lr=self.config['lr_generator'], betas=(0.5, 0.999))

        self.discriminatorB = Discriminator(**self.config['discriminator_params']).cuda()
        self.discriminatorB = DataParallelWithCallback(self.discriminatorB, device_ids=self.device_ids)
        self.optimizer_discriminatorB = torch.optim.Adam(
            self.discriminatorB.parameters(), lr=self.config['lr_discriminator'], betas=(0.5, 0.999))

        if self.BtoA:
            self.generatorA = Generator(**self.config['generator_params']).cuda()
            self.generatorA = DataParallelWithCallback(self.generatorA, device_ids=self.device_ids)
            self.optimizer_generatorA = torch.optim.Adam(
                self.generatorA.parameters(), lr=self.config['lr_generator'], betas=(0.5, 0.999))

            self.discriminatorA = Discriminator(**self.config['discriminator_params']).cuda()
            self.discriminatorA = DataParallelWithCallback(self.discriminatorA, device_ids=self.device_ids)
            self.optimizer_discriminatorA = torch.optim.Adam(
                self.discriminatorA.parameters(), lr=self.config['lr_discriminator'], betas=(0.5, 0.999))

        if checkpoint is not None:
            data = torch.load(checkpoint)
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)

        lr_lambda = lambda epoch: min(1, 2 - 2 * epoch / self.config['num_epochs'])
        self.scheduler_generatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generatorB, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminatorB, lr_lambda, last_epoch=self.epoch - 1)
        if self.BtoA:
            self.scheduler_generatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_generatorA, lr_lambda, last_epoch=self.epoch - 1)
            self.scheduler_discriminatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_discriminatorA, lr_lambda, last_epoch=self.epoch - 1)

    def save(self):
        state_dict = {
            'epoch': self.epoch,
            'generatorB': self.generatorB.state_dict(),
            'optimizer_generatorB': self.optimizer_generatorB.state_dict(),
            'discriminatorB': self.discriminatorB.state_dict(),
            'optimizer_discriminatorB': self.optimizer_discriminatorB.state_dict()
        }
        if self.BtoA:
            state_dict.update({
                'generatorA': self.generatorA.state_dict(),
                'optimizer_generatorA': self.optimizer_generatorA.state_dict(),
                'discriminatorA': self.discriminatorA.state_dict(),
                'optimizer_discriminatorA': self.optimizer_discriminatorA.state_dict()
            })
        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        np.random.seed(0)
        loader = DataLoader(self.dataset, batch_size=self.config['bs'],
                            shuffle=False, drop_last=True, num_workers=4)

        images_fixed = None
        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            loss_dict = defaultdict(lambda: 0.0)
            iteration_count = 0
            for inp in tqdm(loader):
                images_A = inp['A'].cuda()
                images_B = inp['B'].cuda()

                if images_fixed is None:
                    images_fixed = {'A': images_A, 'B': images_B}
                    transform_fixed = Transform(images_A.shape[0], **self.config['transform_params'])

                # Identity losses
                if self.config['identity_loss_weight'] != 0:
                    images_trg = self.generatorB(images_B, source=False)
                    identity_loss = l1(images_trg, images_B)
                    identity_loss = self.config['identity_loss_weight'] * identity_loss
                    identity_loss.backward()
                    loss_dict['identity_loss_B'] += identity_loss.detach().cpu().numpy()

                if self.config['identity_loss_weight'] != 0 and self.BtoA:
                    images_trg = self.generatorA(images_A, source=False)
                    identity_loss = l1(images_trg, images_A)
                    identity_loss = self.config['identity_loss_weight'] * identity_loss
                    identity_loss.backward()
                    loss_dict['identity_loss_A'] += identity_loss.detach().cpu().numpy()

                # Generator losses
                generator_loss = 0

                images_generatedB = self.generatorB(images_A, source=True)
                logits = self.discriminatorB(images_generatedB)
                adversarial_loss = gan_loss_generator(logits, self.config['gan_loss_type'])
                adversarial_loss = self.config['adversarial_loss_weight'] * adversarial_loss
                generator_loss += adversarial_loss
                loss_dict['adversarial_loss_B'] += adversarial_loss.detach().cpu().numpy()

                if self.BtoA:
                    images_generatedA = self.generatorA(images_B, source=True)
                    logits = self.discriminatorA(images_generatedA)
                    adversarial_loss = gan_loss_generator(logits, self.config['gan_loss_type'])
                    adversarial_loss = self.config['adversarial_loss_weight'] * adversarial_loss
                    generator_loss += adversarial_loss
                    loss_dict['adversarial_loss_A'] += adversarial_loss.detach().cpu().numpy()

                if self.config['equivariance_loss_weight_generator'] != 0:
                    transform = Transform(images_generatedB.shape[0], **self.config['transform_params'])
                    images_A_transformed = transform.transform_frame(images_A)
                    loss = corr(self.generatorB(images_A_transformed, source=True),
                                transform.transform_frame(images_generatedB))
                    loss = self.config['equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_B'] += loss.detach().cpu().numpy()

                if self.config['equivariance_loss_weight_generator'] != 0 and self.BtoA:
                    transform = Transform(images_generatedA.shape[0], **self.config['transform_params'])
                    images_B_transformed = transform.transform_frame(images_B)
                    loss = corr(self.generatorA(images_B_transformed, source=True),
                                transform.transform_frame(images_generatedA))
                    loss = self.config['equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_A'] += loss.detach().cpu().numpy()

                if self.BtoA and self.config['cycle_loss_weight'] != 0:
                    images_cycled = self.generatorA(images_generatedB, source=True)
                    cycle_loss = torch.abs(images_cycled - images_A).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_B'] += cycle_loss.detach().cpu().numpy()

                    images_cycled = self.generatorB(images_generatedA, source=True)
                    cycle_loss = torch.abs(images_cycled - images_B).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_A'] += cycle_loss.detach().cpu().numpy()

                generator_loss.backward()

                self.optimizer_generatorB.step()
                self.optimizer_generatorB.zero_grad()
                self.optimizer_discriminatorB.zero_grad()
                if self.BtoA:
                    self.optimizer_generatorA.step()
                    self.optimizer_generatorA.zero_grad()
                    self.optimizer_discriminatorA.zero_grad()

                # Discriminator losses
                logits_fake = self.discriminatorB(images_generatedB.detach())
                logits_real = self.discriminatorB(images_B)
                discriminator_loss = gan_loss_discriminator(logits_real, logits_fake,
                                                            self.config['gan_loss_type'])
                loss_dict['discriminator_loss_B'] += discriminator_loss.detach().cpu().numpy()

                if self.config['equivariance_loss_weight_discriminator'] != 0:
                    images_join = torch.cat([images_generatedB.detach(), images_B])
                    logits_join = torch.cat([logits_fake, logits_real])
                    transform = Transform(images_join.shape[0], **self.config['transform_params'])
                    images_transformed = transform.transform_frame(images_join)
                    loss = corr(self.discriminatorB(images_transformed),
                                transform.transform_frame(logits_join))
                    loss = self.config['equivariance_loss_weight_discriminator'] * loss
                    discriminator_loss += loss
                    loss_dict['equivariance_discriminator_B'] += loss.detach().cpu().numpy()

                discriminator_loss.backward()
                self.optimizer_discriminatorB.step()
                self.optimizer_discriminatorB.zero_grad()
                self.optimizer_generatorB.zero_grad()

                if self.BtoA:
                    logits_fake = self.discriminatorA(images_generatedA.detach())
                    logits_real = self.discriminatorA(images_A)
                    discriminator_loss = gan_loss_discriminator(logits_real, logits_fake,
                                                                self.config['gan_loss_type'])
                    loss_dict['discriminator_loss_A'] += discriminator_loss.detach().cpu().numpy()

                    if self.config['equivariance_loss_weight_discriminator'] != 0:
                        images_join = torch.cat([images_generatedA.detach(), images_A])
                        logits_join = torch.cat([logits_fake, logits_real])
                        transform = Transform(images_join.shape[0], **self.config['transform_params'])
                        images_transformed = transform.transform_frame(images_join)
                        loss = corr(self.discriminatorA(images_transformed),
                                    transform.transform_frame(logits_join))
                        loss = self.config['equivariance_loss_weight_discriminator'] * loss
                        discriminator_loss += loss
                        loss_dict['equivariance_discriminator_A'] += loss.detach().cpu().numpy()

                    discriminator_loss.backward()
                    self.optimizer_discriminatorA.step()
                    self.optimizer_discriminatorA.zero_grad()
                    self.optimizer_generatorA.zero_grad()

                iteration_count += 1

            with torch.no_grad():
                if not self.BtoA:
                    self.generatorB.eval()
                    transformed = transform_fixed.transform_frame(images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'],
                        self.generatorB(images_fixed['A'], source=True),
                        transformed,
                        self.generatorB(transformed, source=True))
                    self.generatorB.train()
                else:
                    self.generatorA.eval()
                    self.generatorB.eval()
                    images_generatedB = self.generatorB(images_fixed['A'], source=True)
                    images_generatedA = self.generatorA(images_fixed['B'], source=True)
                    transformed = transform_fixed.transform_frame(images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'], images_generatedB,
                        transformed, self.generatorB(transformed, source=True),
                        self.generatorA(images_generatedB, source=True),
                        images_fixed['B'], images_generatedA,
                        self.generatorB(images_generatedA, source=True))
                    self.generatorA.train()
                    self.generatorB.train()

            self.scheduler_generatorB.step()
            self.scheduler_discriminatorB.step()
            if self.BtoA:
                self.scheduler_generatorA.step()
                self.scheduler_discriminatorA.step()

            save_dict = {key: value / iteration_count for key, value in loss_dict.items()}
            save_dict['lr'] = self.optimizer_generatorB.param_groups[0]['lr']
            self.logger.log(self.epoch, save_dict)
            self.save()
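
# gan_loss_generator / gan_loss_discriminator above are selected by a
# 'gan_loss_type' string but are not defined in this excerpt. A plausible
# sketch supporting 'lsgan' and 'hinge' objectives (an assumption; the repo's
# own helpers may support different types or names):
import torch
import torch.nn.functional as F


def gan_loss_discriminator_sketch(logits_real, logits_fake, loss_type):
    if loss_type == 'lsgan':
        # least-squares GAN: real -> 1, fake -> 0
        return F.mse_loss(logits_real, torch.ones_like(logits_real)) + \
               F.mse_loss(logits_fake, torch.zeros_like(logits_fake))
    if loss_type == 'hinge':
        return torch.relu(1 - logits_real).mean() + torch.relu(1 + logits_fake).mean()
    raise ValueError('unknown gan_loss_type: %s' % loss_type)


def gan_loss_generator_sketch(logits_fake, loss_type):
    if loss_type == 'lsgan':
        return F.mse_loss(logits_fake, torch.ones_like(logits_fake))
    if loss_type == 'hinge':
        return -logits_fake.mean()
    raise ValueError('unknown gan_loss_type: %s' % loss_type)
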
def main():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--LR', type=list, default=[1e-4, 1e-4], help='learning rate')  # start from 1e-4
    parser.add_argument('--EPOCH', type=int, default=30, help='epoch')
    parser.add_argument('--slice_num', type=int, default=6, help='how many slices to cut')
    parser.add_argument('--batch_size', type=int, default=40, help='batch_size')
    parser.add_argument('--frame_num', type=int, default=5, help='how many frames in a slice')
    parser.add_argument('--model_path', type=str, default='/Disk1/poli/models/DeepRNN/Kinetics_res18',
                        help='model_path')
    parser.add_argument('--model_name', type=str, default='checkpoint', help='model name')
    parser.add_argument('--video_path', type=str, default='/home/poli/kinetics_scaled', help='video path')
    parser.add_argument('--class_num', type=int, default=400, help='class num')
    parser.add_argument('--device_id', type=list, default=[0, 1, 2, 3], help='device ids')
    parser.add_argument('--resume', action='store_true', help='whether resume')
    parser.add_argument('--dropout', type=list, default=[0.2, 0.5], help='dropout')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--saveInter', type=int, default=1, help='save every this many epochs')
    parser.add_argument('--TD_rate', type=float, default=0.0, help='probability of detachout')
    parser.add_argument('--img_size', type=int, default=224, help='image size')
    parser.add_argument('--syn_bn', action='store_true', help='use syn_bn')
    parser.add_argument('--logName', type=str, default='logs_res18', help='log dir name')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--test', action='store_true', help='test the model')
    parser.add_argument('--overlap_rate', type=float, default=0.25,
                        help='the overlap rate of the overlap coherence training scheme')
    parser.add_argument('--lambdaa', type=float, default=0.0,
                        help='weight of the overlap coherence loss')
    opt = parser.parse_args()
    print(opt)

    torch.cuda.set_device(opt.device_id[0])

    # ######################## Module #################################
    print('Building model')
    model = actionModel(opt.class_num,
                        batch_norm=True,
                        dropout=opt.dropout,
                        TD_rate=opt.TD_rate,
                        image_size=opt.img_size,
                        syn_bn=opt.syn_bn,
                        test_scheme=3)
    print(model)
    if opt.syn_bn:
        model = DataParallelWithCallback(model, device_ids=opt.device_id).cuda()
    else:
        model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda()
    print("Channels: " + str(model.module.channels))

    # ######################## Optimizer #########################
    optimizer = torch.optim.SGD([{
        'params': model.module.RNN.parameters(),
        'lr': opt.LR[0]
    }, {
        'params': model.module.ShortCut.parameters(),
        'lr': opt.LR[0]
    }, {
        'params': model.module.classifier.parameters(),
        'lr': opt.LR[1]
    }], lr=opt.LR[1], weight_decay=opt.weight_decay, momentum=0.9)

    # ###################### Loss Function ####################################
    loss_classification_func = nn.NLLLoss(reduce=True)

    def loss_overlap_coherence_func(pre, cur):
        loss = nn.MSELoss()
        return loss(cur, pre.detach())

    # ###################### Resume ##########################################
    resume_epoch = 0
    resume_step = 0
    max_test_acc = 0
    if opt.resume or opt.test:
        print("loading model")
        # map every source device in the checkpoint onto the first configured device
        checkpoint = torch.load(
            opt.model_path + '/' + opt.model_name,
            map_location={'cuda:%d' % i: 'cuda:' + str(opt.device_id[0]) for i in range(8)})
        model.load_state_dict(checkpoint['model'], strict=True)
        try:
            optimizer.load_state_dict(checkpoint['opt'])
        except:
            pass
        for group_id, param_group in enumerate(optimizer.param_groups):
            if group_id == 0:
                param_group['lr'] = opt.LR[0]
            elif group_id == 1:
                param_group['lr'] = opt.LR[0]
            elif group_id == 2:
                param_group['lr'] = opt.LR[1]
        resume_epoch = checkpoint['epoch']
        if 'step' in checkpoint:
            resume_step = checkpoint['step'] + 1
        if 'max_acc' in checkpoint:
            max_test_acc = checkpoint['max_acc']
        print('Finish Loading')
        del checkpoint

    # ###########################################################################
    # training and testing
    model.train()
    predict_for_mAP = []
    label_for_mAP = []

    print("START")
    KineticsLoader = torch.utils.data.DataLoader(
        Kinetic_train_dataset.Kinetics(video_path=opt.video_path + '/train_frames',
                                       frame_num=opt.frame_num,
                                       batch_size=opt.batch_size,
                                       img_size=opt.img_size,
                                       slice_num=opt.slice_num,
                                       overlap_rate=opt.overlap_rate),
        batch_size=1, shuffle=True, num_workers=8)

    Loader_test = torch.utils.data.DataLoader(
        Kinetics_test_dataset.Kinetics(video_path=opt.video_path + '/val_frames',
                                       img_size=224,
                                       space=5,
                                       split_num=8,
                                       lenn=60,
                                       num_class=opt.class_num),
        batch_size=64, shuffle=True, num_workers=4)

    tensorboard_writer = SummaryWriter(
        opt.logName,
        purge_step=resume_epoch * len(KineticsLoader) * opt.slice_num +
        (resume_step + resume_step) * opt.slice_num)

    test = opt.test
    for epoch in range(resume_epoch, opt.EPOCH):
        predict_for_mAP = []
        label_for_mAP = []
        for step, (x, _, overlap_frame_num, action) in enumerate(KineticsLoader):  # gives batch data
            if opt.train:
                if step + resume_step >= len(KineticsLoader):
                    break
                x = x[0]
                action = action[0]
                overlap_frame_num = overlap_frame_num[0]
                c = [
                    Variable(torch.from_numpy(np.zeros(
                        (x.shape[1],
                         model.module.channels[layer + 1],
                         model.module.input_size[layer],
                         model.module.input_size[layer])))).cuda().float()
                    for layer in range(model.module.RNN_layer)
                ]

                for slice in range(x.shape[0]):
                    b_x = Variable(x[slice]).cuda()
                    b_action = Variable(action[slice]).cuda()
                    out, out_beforeMerge, c = model(b_x.float(), c)  # rnn output

                    for batch in range(len(out)):
                        predict_for_mAP.append(out[batch].data.cpu().numpy())
                        label_for_mAP.append(b_action[batch][-1].data.cpu().numpy())

                    # ###################### overlap coherence loss ######################
                    loss_coherence = torch.zeros(1).cuda()
                    # calculate the coherence loss between the previous clip and the current clip
                    if slice != 0:
                        for b in range(out.size()[0]):
                            loss_coherence += loss_overlap_coherence_func(
                                old_overlap[b],
                                torch.exp(out_beforeMerge[b, :overlap_frame_num[slice, b, 0].int()]))
                        loss_coherence = loss_coherence / out.size()[0]
                    # record the previous clip's output
                    old_overlap = []
                    for b in range(out.size()[0]):
                        old_overlap.append(
                            torch.exp(out_beforeMerge[b, -overlap_frame_num[slice, b, 0].int():]))
                    # ####################################################################

                    loss_classification = loss_classification_func(out, b_action[:, -1].long())
                    loss = loss_classification + opt.lambdaa * loss_coherence

                    tensorboard_writer.add_scalar(
                        'train/loss', loss,
                        epoch * len(KineticsLoader) * opt.slice_num +
                        (step + resume_step) * opt.slice_num + slice)

                    loss.backward(retain_graph=False)

                    mAPs = mAP(predict_for_mAP, label_for_mAP, 'Lsm')
                    acc = accuracy(predict_for_mAP, label_for_mAP, 'Lsm')
                    tensorboard_writer.add_scalar(
                        'train/mAP', mAPs,
                        epoch * len(KineticsLoader) * opt.slice_num +
                        (step + resume_step) * opt.slice_num + slice)
                    tensorboard_writer.add_scalar(
                        'train/acc', acc,
                        epoch * len(KineticsLoader) * opt.slice_num +
                        (step + resume_step) * opt.slice_num + slice)
                    print("Epoch: " + str(epoch) +
                          " step: " + str(step + resume_step) +
                          " Loss: " + str(loss.data.cpu().numpy()) +
                          " Loss_coherence: " + str(loss_coherence.data.cpu().numpy()) +
                          " mAP: " + str(mAPs)[0:7] +
                          " acc: " + str(acc)[0:7])

                    # clamp gradients, then update every second step
                    for p in model.module.parameters():
                        p.grad.data.clamp_(min=-5, max=5)
                    if step % 2 == 1:
                        optimizer.step()
                        optimizer.zero_grad()

                    predict_for_mAP = []
                    label_for_mAP = []

            # ################################### test ###############################
            if (step + resume_step) % 700 == 699:
                test = True
            if test:
                print('Start Test')
                TEST_LOSS = AverageMeter()
                with torch.no_grad():
                    model.eval()
                    predict_for_mAP = []
                    label_for_mAP = []
                    print("TESTING")
                    for step_test, (x, _, _, action) in tqdm(enumerate(Loader_test)):  # gives batch data
                        b_x = Variable(x).cuda()
                        b_action = Variable(action).cuda()
                        c = [
                            Variable(torch.from_numpy(np.zeros(
                                (len(b_x),
                                 model.module.channels[layer + 1],
                                 model.module.input_size[layer],
                                 model.module.input_size[layer])))).cuda().float()
                            for layer in range(model.module.RNN_layer)
                        ]
                        out, _, _ = model(b_x.float(), c)  # rnn output
                        loss = loss_classification_func(out, b_action[:, -1].long())
                        TEST_LOSS.update(val=loss.data.cpu().numpy())
                        for batch in range(len(out)):
                            predict_for_mAP.append(out[batch].data.cpu().numpy())
                            label_for_mAP.append(b_action[batch][-1].data.cpu().numpy())
                        if step_test % 50 == 0:
                            MAP = mAP(np.array(predict_for_mAP), np.array(label_for_mAP), 'Lsm')
                            acc = accuracy(np.array(predict_for_mAP), np.array(label_for_mAP), 'Lsm')
                            print(" Loss: " + str(TEST_LOSS.avg)[0:5] + ' ' +
                                  'accuracy: ' + str(acc)[0:7])

                    predict_for_mAP = np.array(predict_for_mAP)
                    label_for_mAP = np.array(label_for_mAP)
                    MAP = mAP(predict_for_mAP, label_for_mAP, 'Lsm')
                    acc = accuracy(predict_for_mAP, label_for_mAP, 'Lsm')
                    print("mAP: " + str(MAP) + ' ' + 'accuracy: ' + str(acc))
                    if acc > max_test_acc:
                        print('Saving')
                        max_test_acc = acc
                        torch.save(
                            {
                                'model': model.state_dict(),
                                'max_acc': max_test_acc,
                                'epoch': epoch,
                                'step': 0,
                                'opt': optimizer.state_dict()
                            },
                            opt.model_path + '/' + opt.model_name + '_' + str(epoch) +
                            '_' + str(max_test_acc)[0:6])
                    model.train()
                    test = False
                    predict_for_mAP = []
                    label_for_mAP = []
                if opt.test:
                    exit()

        if epoch % opt.saveInter == 0:
            print('Saving')
            torch.save(
                {
                    'model': model.state_dict(),
                    'max_acc': max_test_acc,
                    'epoch': epoch,
                    'step': 0,
                    'opt': optimizer.state_dict()
                },
                opt.model_path + '/' + opt.model_name + '_' + str(epoch))
        resume_step = 0
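
# The training loop above accumulates gradients over two loader steps
# (optimizer.step() only when step % 2 == 1) and clamps each gradient element
# to [-5, 5] before stepping. A minimal self-contained sketch of that pattern
# (toy model and data; the clamp range and accumulation factor mirror the
# values used above):
import torch
import torch.nn as nn

if __name__ == '__main__':
    toy_model = nn.Linear(10, 4)
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-4, momentum=0.9)

    for step in range(4):
        x = torch.randn(8, 10)
        target = torch.randint(0, 4, (8,))
        loss = nn.functional.cross_entropy(toy_model(x), target)
        loss.backward()                        # gradients accumulate across steps
        for p in toy_model.parameters():
            p.grad.data.clamp_(min=-5, max=5)  # element-wise gradient clamping
        if step % 2 == 1:                      # update every second step
            toy_optimizer.step()
            toy_optimizer.zero_grad()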