def _get_sampler(train_set, test_set, val_set, train_sampler, test_sampler,
                 val_sampler, start_epoch):
    if train_sampler is None:
        if is_distributed():
            train_sampler = DistributedSampler(
                train_set,
                num_replicas=get_world_size(),
                rank=get_global_rank())
            train_sampler.set_epoch(start_epoch)
        else:
            train_sampler = RandomSampler(train_set, True)
    else:
        train_sampler = train_sampler(train_set)

    if test_sampler is None:
        if is_distributed():
            test_sampler = DistributedSampler(
                test_set,
                num_replicas=get_world_size(),
                rank=get_global_rank())
    else:
        test_sampler = test_sampler(test_set)

    if val_set is not None:
        if val_sampler is None and is_distributed():
            val_sampler = DistributedSampler(
                val_set,
                num_replicas=get_world_size(),
                rank=get_global_rank())
            val_sampler.set_epoch(start_epoch)
        elif val_sampler is not None:
            val_sampler = val_sampler(val_set)

    return train_sampler, test_sampler, val_sampler
def get_ddp_sampler(dataset: Dataset, epoch: int):
    """
    This function will create a DistributedSampler if DDP is initialized,
    and will just return None if DDP is not initialized.
    """
    if is_initialized():
        sampler = DistributedSampler(dataset)
        sampler.set_epoch(epoch)
    else:
        sampler = None
    return sampler
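# Hedged usage sketch for get_ddp_sampler above (not part of the original
# snippet); `make_loader_example` and its arguments are hypothetical, and
# DataLoader/DistributedSampler are assumed imported as in the other snippets.
# When the sampler is None (single-process run), shuffling falls back to the
# DataLoader itself; when a DistributedSampler is returned, shuffle must stay
# False because the sampler already shuffles per epoch.
def make_loader_example(dataset, epoch, batch_size=32):
    sampler = get_ddp_sampler(dataset, epoch)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
    )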
def build_data_loader(
    image_path: Union[str, Path],
    config: dict,
    uses_absolute_paths: bool,
    shuffle_off: bool = False,
    dataset_class: Type[AutoencoderDataset] = AutoencoderDataset
) -> DataLoader:
    transform_list = [
        transforms.Resize((config['image_size'], config['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ) * config['input_dim'],
                             (0.5, ) * config['input_dim'])
    ]
    transform_list = transforms.Compose(transform_list)

    dataset = dataset_class(
        image_path,
        root=os.path.dirname(image_path) if not uses_absolute_paths else None,
        transforms=transform_list,
        loader=resilient_loader,
    )

    sampler = None
    if get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=not shuffle_off)
        sampler.set_epoch(get_rank())

    if shuffle_off:
        shuffle = False
    else:
        shuffle = sampler is None

    loader = DataLoader(
        dataset,
        config['batch_size'],
        shuffle=shuffle,
        drop_last=True,
        sampler=sampler,
    )
    return loader
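# Hedged sketch, not from the original snippet: build_data_loader above calls
# set_epoch only once, at construction time. For epoch-varying shuffling under
# DDP the sampler's epoch is normally advanced from the training loop instead;
# `config`, `num_epochs`, and the "train.txt" path are hypothetical.
loader = build_data_loader(image_path="train.txt", config=config,
                           uses_absolute_paths=False)
for epoch in range(num_epochs):
    if isinstance(loader.sampler, DistributedSampler):
        loader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
    for batch in loader:
        pass  # training step goes here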
class BalancedBatchSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas,
        rank,
        device,
        mode="atoms",
        shuffle=True,
        drop_last=False,
        force_balancing=False,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.device = device
        self.mode = mode.lower()
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.balance_batches = self.num_replicas > 1
        if self.balance_batches:
            if (
                not hasattr(dataset, "metadata_path")
                or not dataset.metadata_path.is_file()
            ):
                if force_balancing:
                    logging.warning(
                        f"No metadata file found at '{dataset.metadata_path}'. "
                        "BalancedBatchSampler has to load the data to "
                        "determine batch sizes, which incurs "
                        "significant overhead!"
                    )
                    self.sizes = None
                else:
                    logging.warning(
                        f"No metadata file found at '{dataset.metadata_path}'. "
                        "Batches will not be balanced, "
                        "which can incur significant overhead!"
                    )
                    self.balance_batches = False
                    self.sizes = None
            else:
                if self.mode == "atoms":
                    self.sizes = np.load(dataset.metadata_path)["natoms"]
                elif self.mode == "neighbors":
                    self.sizes = np.load(dataset.metadata_path)["neighbors"]
                else:
                    raise NotImplementedError(
                        f"Unknown load balancing mode: {self.mode}"
                    )
        else:
            self.sizes = None

        self.single_sampler = DistributedSampler(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        self.batch_sampler = BatchSampler(
            self.single_sampler,
            batch_size,
            drop_last=drop_last,
        )

    def __len__(self):
        return len(self.batch_sampler)

    def set_epoch(self, epoch):
        self.single_sampler.set_epoch(epoch)

    def __iter__(self):
        for batch_idx in self.batch_sampler:
            if self.balance_batches:
                if self.sizes is None:
                    # Unfortunately, we need to load the data to know the image sizes
                    data_list = [self.dataset[idx] for idx in batch_idx]
                    if self.mode == "atoms":
                        sizes = [data.num_nodes for data in data_list]
                    elif self.mode == "neighbors":
                        sizes = [
                            data.edge_index.shape[1] for data in data_list
                        ]
                    else:
                        raise NotImplementedError(
                            f"Unknown load balancing mode: {self.mode}"
                        )
                else:
                    sizes = [self.sizes[idx] for idx in batch_idx]

                idx_sizes = torch.stack(
                    [torch.tensor(batch_idx), torch.tensor(sizes)]
                )
                idx_sizes_all = distutils.all_gather(
                    idx_sizes, device=self.device
                )
                idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
                idx_all = idx_sizes_all[0]
                sizes_all = idx_sizes_all[1]

                local_idx_balanced = balanced_partition(
                    sizes_all.numpy(), num_parts=self.num_replicas
                )
                # Since DistributedSampler pads the last batch
                # this should always have an entry for each replica.
                yield idx_all[local_idx_balanced[self.rank]]
            else:
                yield batch_idx
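# Hedged usage sketch for BalancedBatchSampler above; `dataset`, `collater`,
# `device`, `num_epochs`, and the distutils helpers are assumptions borrowed
# from the surrounding snippets, not verified API. A batch sampler is passed
# via `batch_sampler=`, and its epoch is advanced each epoch just like a plain
# DistributedSampler.
sampler = BalancedBatchSampler(
    dataset,
    batch_size=16,
    num_replicas=distutils.get_world_size(),
    rank=distutils.get_rank(),
    device=device,
    mode="atoms",
    shuffle=True,
)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collater)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward/backward step goes here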
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) # IPython.embed() # IPython.embed() # os.system("sudo chmod -R 777 /home/shuxuang/.cache/") model, criterion, postprocessors = build_model( args) # use the same model as detr paper on coco model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) # dataset_train = build_dataset(image_set='train', args=args) # dataset_val = build_dataset(image_set='val', args=args) # modify the dataset from coco to nvdata # home_dir = os.environ["HOME"] # dataset_train_ = build_nvdataset(dataset_root=[ # os.path.join(os.environ["HOME"],'datasets/annotation_sql_nvidia'), # os.path.join(os.environ["HOME"], 'datasets/frames_nvidia')], # mode='train') # dataset_val = build_nvdataset(dataset_root=[ # os.path.join(os.environ["HOME"],'datasets/test'), # os.path.join(os.environ["HOME"], 'datasets/frames_nvidia')], # mode='test') # indices_50k =np.load(os.path.join(os.environ["HOME"],'datasets/id_1_criterion_Max_SSD_num_labels_50000.npy')) dataset_train = build_nvdataset( dataset_root=[args.dataset_root_sql, args.dataset_root_img], mode='train', camera=args.camera) dataset_val = build_nvdataset( dataset_root=[args.dataset_root_test, args.dataset_root_test], mode='test', camera=args.camera) if args.root_indices is not None: indices_50k = np.load(os.path.join(args.root_indices)) # indices_50k =np.load(os.path.join(os.environ["HOME"],'datasets/id_1_criterion_Max_SSD_num_labels_50000.npy')) dataset_train = Subset(dataset_train, indices_50k) # IPython.embed() print("Train samples: %d" % (len(dataset_train))) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) # if args.dataset_file == "coco_panoptic": # # We also evaluate AP during panoptic training, on original coco DS # coco_val = datasets.coco.build("val", args) # base_ds = get_coco_api_from_dataset(coco_val) # elif args.dataset_file == "nvdata": # coco_val = datasets.coco.build("val", args) # base_ds = get_coco_api_from_dataset(coco_val) # else: # 
base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: checkpoint = torch.load(args.frozen_weights, map_location='cpu') model_without_ddp.detr.load_state_dict(checkpoint['model']) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 # if args.eval: # test_stats, coco_evaluator = evaluate_nvdata(model, criterion, postprocessors, # data_loader_val, base_ds, device, args.output_dir) # if args.output_dir: # utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") # return # if args.eval: # evaluate(model, dataset_val, postprocessors, device) print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 50 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) # test_stats, coco_evaluator = evaluate_nvdata( # model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir # ) # log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, # **{f'test_{k}': v for k, v in test_stats.items()}, # 'epoch': epoch, # 'n_parameters': n_parameters} log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") # for evaluation logs # if coco_evaluator is not None: # (output_dir / 'eval').mkdir(exist_ok=True) # if "bbox" in coco_evaluator.coco_eval: # filenames = ['latest.pth'] # if epoch % 50 == 0: # filenames.append(f'{epoch:03}.pth') # for name in filenames: # torch.save(coco_evaluator.coco_eval["bbox"].eval, # output_dir / "eval" / name) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
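# The DETR-style main() functions in this section all share the same sampler
# wiring. A minimal distillation (hedged sketch; `distributed`, dataset_train,
# dataset_val, batch_size, and num_epochs are hypothetical names): the train
# sampler is wrapped in a BatchSampler with drop_last=True, the val sampler
# keeps the original order, and set_epoch() is called once per epoch only on
# the training sampler.
if distributed:
    sampler_train = DistributedSampler(dataset_train)
    sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

batch_sampler_train = torch.utils.data.BatchSampler(
    sampler_train, batch_size, drop_last=True)
loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train)
loader_val = DataLoader(dataset_val, batch_size, sampler=sampler_val,
                        drop_last=False)

for epoch in range(num_epochs):
    if distributed:
        sampler_train.set_epoch(epoch)
    # train_one_epoch(...) / evaluate(...) go here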
def train(rank, a, h): if h.num_gpus > 1: init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) torch.cuda.manual_seed(h.seed) torch.cuda.set_device(rank) device = torch.device('cuda:{:d}'.format(rank)) generator = Generator(h).to(device) mpd = MultiPeriodDiscriminator().to(device) msd = MultiScaleDiscriminator().to(device) if rank == 0: print(generator) os.makedirs(a.checkpoint_path, exist_ok=True) print("checkpoints directory : ", a.checkpoint_path) if os.path.isdir(a.checkpoint_path): cp_g = scan_checkpoint(a.checkpoint_path, 'g_') cp_do = scan_checkpoint(a.checkpoint_path, 'do_') steps = 0 if cp_g is None or cp_do is None: state_dict_do = None last_epoch = -1 else: state_dict_g = load_checkpoint(cp_g, device) state_dict_do = load_checkpoint(cp_do, device) generator.load_state_dict(state_dict_g['generator']) mpd.load_state_dict(state_dict_do['mpd']) msd.load_state_dict(state_dict_do['msd']) steps = state_dict_do['steps'] + 1 last_epoch = state_dict_do['epoch'] if h.num_gpus > 1: generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) if state_dict_do is not None: optim_g.load_state_dict(state_dict_do['optim_g']) optim_d.load_state_dict(state_dict_do['optim_d']) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) training_filelist, validation_filelist = get_dataset_filelist(a) trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, sampler=train_sampler, batch_size=h.batch_size, pin_memory=True, drop_last=True) if rank == 0: validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir) validation_loader = DataLoader(validset, num_workers=1, shuffle=False, sampler=None, batch_size=1, pin_memory=True, drop_last=True) sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) generator.train() mpd.train() msd.train() for epoch in range(max(0, last_epoch), a.training_epochs): if rank == 0: start = time.time() print("Epoch: {}".format(epoch + 1)) if h.num_gpus > 1: train_sampler.set_epoch(epoch) for i, batch in enumerate(train_loader): if rank == 0: start_b = time.time() x, y, _, y_mel = batch x = torch.autograd.Variable(x.to(device, non_blocking=True)) y = torch.autograd.Variable(y.to(device, non_blocking=True)) y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) y = y.unsqueeze(1) y_g_hat = generator(x) y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), 
h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax_for_loss) optim_d.zero_grad() # MPD y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( y_df_hat_r, y_df_hat_g) # MSD y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( y_ds_hat_r, y_ds_hat_g) loss_disc_all = loss_disc_s + loss_disc_f loss_disc_all.backward() optim_d.step() # Generator optim_g.zero_grad() # L1 Mel-Spectrogram Loss loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45 y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel loss_gen_all.backward() optim_g.step() if rank == 0: # STDOUT logging if steps % a.stdout_interval == 0: with torch.no_grad(): mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() print( 'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}' .format(steps, loss_gen_all, mel_error, time.time() - start_b)) # checkpointing if steps % a.checkpoint_interval == 0 and steps != 0: checkpoint_path = "{}/g_{:08d}".format( a.checkpoint_path, steps) save_checkpoint( checkpoint_path, { 'generator': (generator.module if h.num_gpus > 1 else generator).state_dict() }) checkpoint_path = "{}/do_{:08d}".format( a.checkpoint_path, steps) save_checkpoint( checkpoint_path, { 'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), 'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(), 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 'epoch': epoch }) # Tensorboard summary logging if steps % a.summary_interval == 0: sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) sw.add_scalar("training/mel_spec_error", mel_error, steps) # Validation if steps % a.validation_interval == 0: # and steps != 0: generator.eval() torch.cuda.empty_cache() val_err_tot = 0 with torch.no_grad(): for j, batch in enumerate(validation_loader): x, y, _, y_mel = batch y_g_hat = generator(x.to(device)) y_mel = torch.autograd.Variable( y_mel.to(device, non_blocking=True)) y_g_hat_mel = mel_spectrogram( y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax_for_loss) val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() if j <= 4: if steps == 0: sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate) sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps) sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate) y_hat_spec = mel_spectrogram( y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) sw.add_figure( 'generated/y_hat_spec_{}'.format(j), plot_spectrogram( y_hat_spec.squeeze(0).cpu().numpy()), steps) val_err = val_err_tot / (j + 1) sw.add_scalar("validation/mel_spec_error", val_err, steps) generator.train() steps += 1 scheduler_g.step() scheduler_d.step() if rank == 0: print('Time taken for epoch {} is {} sec\n'.format( epoch + 1, int(time.time() - start)))
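# Hedged launcher sketch for the train(rank, a, h) function above, following
# the usual per-rank multiprocessing pattern: mp.spawn calls train(i, a, h)
# once per GPU with the process index as the first argument. `a` (args) and
# `h` (hparams) come from the script's own config parsing, which is not shown
# in the snippet, so this is an assumption rather than the original entry point.
import torch.multiprocessing as mp

def launch(a, h):
    if h.num_gpus > 1:
        mp.spawn(train, nprocs=h.num_gpus, args=(a, h))
    else:
        train(0, a, h)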
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) wandb.init(project="qpic-project", entity="sangbaeklee", group="experiment_qpic") wandb.config = { "learning_rate": args.lr, "epochs": args.epochs, "batch_size": args.batch_size, } if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) wandb.watch(model) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set='train', args=args) dataset_val = build_dataset(image_set='val', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) if not args.hoi: if args.dataset_file == "coco_panoptic": # We also evaluate AP during panoptic training, on original coco DS coco_val = datasets.coco.build("val", args) base_ds = get_coco_api_from_dataset(coco_val) else: base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: checkpoint = torch.load(args.frozen_weights, map_location='cpu') model_without_ddp.detr.load_state_dict(checkpoint['model']) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 elif args.pretrained: checkpoint = torch.load(args.pretrained, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model'], strict=False) if args.eval: if args.hoi: test_stats = evaluate_hoi(args.dataset_file, model, postprocessors, data_loader_val, args.subject_category_id, device) return else: test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir) 
if args.output_dir: utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) if args.hoi: test_stats = evaluate_hoi(args.dataset_file, model, postprocessors, data_loader_val, args.subject_category_id, device) coco_evaluator = None else: test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir) log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } #import pdb; pdb.set_trace() if args.dataset_file == 'hico': wandb.log({ "loss": train_stats['loss'], "mAP": test_stats['mAP'], "mAP rare": test_stats['mAP rare'], "mAP non-rare": test_stats['mAP non-rare'], "mean max recall": test_stats['mean max recall'] }) elif args.dataset_file == 'vcoco': wandb.log({ "mAP_all": test_stats['mAP_all'], "mAP_thesis": test_stats['mAP_thesis'], "AP_hold_obj": test_stats['AP_hold_obj'], "AP_stand": test_stats['AP_stand'], "AP_sit_instr": test_stats['AP_sit_instr'], "AP_ride_instr": test_stats['AP_ride_instr'], "AP_walk": test_stats['AP_walk'], "AP_look_obj": test_stats['AP_look_obj'], "AP_hit_instr": test_stats['AP_hit_instr'], "AP_hit_obj": test_stats['AP_hit_obj'], "AP_eat_obj": test_stats['AP_eat_obj'], "AP_eat_instr": test_stats['AP_eat_instr'], "AP_jump_instr": test_stats['AP_jump_instr'], "AP_lay_instr": test_stats['AP_lay_instr'], "AP_talk_on_phone_instr": test_stats['AP_talk_on_phone_instr'], "AP_carry_obj": test_stats['AP_carry_obj'], "AP_throw_obj": test_stats['AP_throw_obj'], "AP_catch_obj": test_stats['AP_catch_obj'], "AP_cut_instr": test_stats['AP_cut_instr'], "AP_cut_obj": test_stats['AP_cut_obj'], "AP_run": test_stats['AP_run'], "AP_work_on_computer_instr": test_stats['AP_work_on_computer_instr'], "AP_ski_instr": test_stats['AP_ski_instr'], "AP_surf_instr": test_stats['AP_surf_instr'], "AP_skateboard_instr": test_stats['AP_skateboard_instr'], "AP_smile": test_stats['AP_smile'], "AP_drink_instr": test_stats['AP_drink_instr'], "AP_kick_obj": test_stats['AP_kick_obj'], "AP_point_instr": test_stats['AP_point_instr'], "AP_read_obj": test_stats['AP_read_obj'], "AP_snowboard_instr": test_stats['AP_snowboard_instr'],\ "loss" : train_stats['loss'] }) else: continue if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") # for evaluation logs if coco_evaluator is not None: (output_dir / 'eval').mkdir(exist_ok=True) if "bbox" in coco_evaluator.coco_eval: filenames = ['latest.pth'] if epoch % 50 == 0: filenames.append(f'{epoch:03}.pth') for name in filenames: torch.save(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name) total_time = time.time() - 
start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
def main(args): bz = args.batch_size lr = args.lr if args.cuda: if torch.cuda.device_count() >= 1: utils.init_distributed_mode(args) device = torch.device(args.device) else: device = torch.device('cpu') # fix the seed for reproducibility if args.cuda: seed = args.seed + utils.get_rank() else: seed = args.seed torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) # set up model model, criterion, postprocessors = build_model(args) model_without_ddp = model if args.cuda and args.distributed: if args.mp: model = torch.nn.parallel.DistributedDataParallel(model) else: model = torch.nn.parallel.DistributedDataParallel( model.to(args.gpu), device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module elif args.cuda: model = model.to(device) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) # set up model training param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "joiner" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "joiner" in n and p.requires_grad ], "lr": args.lr_joiner, }, ] # datasets build dataset_train = build_dataset(mode="training", args=args) dataset_test = build_dataset(mode="testing", args=args) if args.cuda and args.distributed: sampler_train = DistributedSampler(dataset_train, shuffle=False) sampler_test = DistributedSampler(dataset_test, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_test = torch.utils.data.SequentialSampler(dataset_test) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_test = DataLoader(dataset_test, 1, sampler=sampler_test, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) # output and checkpoints directory checkpoint_dir = args.output_dir if not os.path.exists(checkpoint_dir): try: os.makedirs(checkpoint_dir) except OSError: pass if args.resume: checkpoint = Path(args.resume) assert checkpoint.exists() checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 print("Start Training") start_time = time.time() optimizer.zero_grad() for epoch in range(args.start_epoch, args.epochs): if args.cuda and args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(epoch, args.clip_max_norm, model, criterion, data_loader_train, optimizer, lr_scheduler, device) if args.output_dir: checkpoint_dir = Path(checkpoint_dir) checkpoint_paths = [checkpoint_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or ( epoch + 1) % args.save_checkpoint_every == 0: checkpoint_paths.append(checkpoint_dir / f'checkpoint{epoch:05}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': 
lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) if (epoch + 1) % args.eval_interval == 0: # evaluation test_stats = evaluate(epoch, model, criterion, postprocessors, data_loader_test, args.output_dir, args.dataset, device) log_stats = { **{'train_' + str(k): v for k, v in train_stats.items()}, **{'test_' + str(k): v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (checkpoint_dir / 'log.json').open("a") as f: f.write(json.dumps(log_stats) + "\n") lr_scheduler.step() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) # no validation ground truth for ytvos dataset dataset_train = build_dataset(image_set='train', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) output_dir = Path(args.output_dir) # load coco pretrained weight checkpoint = torch.load(args.pretrained_weights, map_location='cpu')['model'] del checkpoint["vistr.class_embed.weight"] del checkpoint["vistr.class_embed.bias"] del checkpoint["vistr.query_embed.weight"] model.module.load_state_dict(checkpoint, strict=False) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
class EnergyTrainer(BaseTrainer): """ Trainer class for the Initial Structure to Relaxed Energy (IS2RE) task. .. note:: Examples of configurations for task, model, dataset and optimizer can be found in `configs/ocp_is2re <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2re/>`_. Args: task (dict): Task configuration. model (dict): Model configuration. dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. optimizer (dict): Optimizer configuration. identifier (str): Experiment identifier that is appended to log directory. run_dir (str, optional): Path to the run directory where logs are to be saved. (default: :obj:`None`) is_debug (bool, optional): Run in debug mode. (default: :obj:`False`) is_vis (bool, optional): Run in debug mode. (default: :obj:`False`) print_every (int, optional): Frequency of printing logs. (default: :obj:`100`) seed (int, optional): Random number seed. (default: :obj:`None`) logger (str, optional): Type of logger to be used. (default: :obj:`tensorboard`) local_rank (int, optional): Local rank of the process, only applicable for distributed training. (default: :obj:`0`) amp (bool, optional): Run using automatic mixed precision. (default: :obj:`False`) """ def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, cpu=False, ): super().__init__( task=task, model=model, dataset=dataset, optimizer=optimizer, identifier=identifier, run_dir=run_dir, is_debug=is_debug, is_vis=is_vis, print_every=print_every, seed=seed, logger=logger, local_rank=local_rank, amp=amp, cpu=cpu, name="is2re", ) def load_task(self): assert (self.config["task"]["dataset"] == "single_point_lmdb" ), "EnergyTrainer requires single_point_lmdb dataset" print("### Loading dataset: {}".format(self.config["task"]["dataset"])) self.parallel_collater = ParallelCollater( 1 if not self.cpu else 0, self.config["model_attributes"].get("otf_graph", False), ) self.train_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["dataset"]) self.train_sampler = DistributedSampler( self.train_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=True, ) self.train_loader = DataLoader( self.train_dataset, batch_size=self.config["optim"]["batch_size"], collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.train_sampler, ) self.val_loader = self.test_loader = None self.val_sampler = None if "val_dataset" in self.config: self.val_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["val_dataset"]) self.val_sampler = DistributedSampler( self.val_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=False, ) self.val_loader = DataLoader( self.val_dataset, self.config["optim"].get("eval_batch_size", 64), collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.val_sampler, ) if "test_dataset" in self.config: self.test_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["test_dataset"]) self.test_sampler = DistributedSampler( self.test_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=False, ) self.test_loader = DataLoader( self.test_dataset, self.config["optim"].get("eval_batch_size", 64), collate_fn=self.parallel_collater, 
num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.test_sampler, ) self.num_targets = 1 # Normalizer for the dataset. # Compute mean, std of training set labels. self.normalizers = {} if self.config["dataset"].get("normalize_labels", False): if "target_mean" in self.config["dataset"]: self.normalizers["target"] = Normalizer( mean=self.config["dataset"]["target_mean"], std=self.config["dataset"]["target_std"], device=self.device, ) else: raise NotImplementedError def predict(self, loader, results_file=None, disable_tqdm=False): if distutils.is_master() and not disable_tqdm: print("### Predicting on test.") assert isinstance(loader, torch.utils.data.dataloader.DataLoader) rank = distutils.get_rank() self.model.eval() if self.normalizers is not None and "target" in self.normalizers: self.normalizers["target"].to(self.device) predictions = {"id": [], "energy": []} for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) if self.normalizers is not None and "target" in self.normalizers: out["energy"] = self.normalizers["target"].denorm( out["energy"]) predictions["id"].extend([str(i) for i in batch[0].sid.tolist()]) predictions["energy"].extend(out["energy"].tolist()) self.save_results(predictions, results_file, keys=["energy"]) return predictions def train(self): self.best_val_mae = 1e9 start_epoch = self.start_step // len(self.train_loader) for epoch in range(start_epoch, self.config["optim"]["max_epochs"]): self.train_sampler.set_epoch(epoch) self.model.train() skip_steps = 0 if epoch == start_epoch and start_epoch > 0: skip_steps = start_epoch % len(self.train_loader) train_loader_iter = iter(self.train_loader) for i in range(skip_steps, len(self.train_loader)): batch = next(train_loader_iter) # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss = self.scaler.scale(loss) if self.scaler else loss self._backward(loss) scale = self.scaler.get_scale() if self.scaler else 1.0 # Compute metrics. self.metrics = self._compute_metrics( out, batch, self.evaluator, metrics={}, ) self.metrics = self.evaluator.update("loss", loss.item() / scale, self.metrics) # Print metrics, make plots. 
log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} log_dict.update( {"epoch": epoch + (i + 1) / len(self.train_loader)}) if (i % self.config["cmd"]["print_every"] == 0 and distutils.is_master()): log_str = [ "{}: {:.4f}".format(k, v) for k, v in log_dict.items() ] print(", ".join(log_str)) if self.logger is not None: self.logger.log( log_dict, step=epoch * len(self.train_loader) + i + 1, split="train", ) if self.update_lr_on_step: self.scheduler.step() if not self.update_lr_on_step: self.scheduler.step() torch.cuda.empty_cache() if self.val_loader is not None: val_metrics = self.validate(split="val", epoch=epoch) if (val_metrics[self.evaluator.task_primary_metric[self.name]] ["metric"] < self.best_val_mae): self.best_val_mae = val_metrics[ self.evaluator.task_primary_metric[ self.name]]["metric"] current_step = (epoch + 1) * len(self.train_loader) self.save(epoch + 1, current_step, val_metrics) if self.test_loader is not None: self.predict( self.test_loader, results_file="predictions", disable_tqdm=False, ) else: current_step = (epoch + 1) * len(self.train_loader) self.save(epoch + 1, current_step, self.metrics) self.train_dataset.close_db() if "val_dataset" in self.config: self.val_dataset.close_db() if "test_dataset" in self.config: self.test_dataset.close_db() def _forward(self, batch_list): output = self.model(batch_list) if output.shape[-1] == 1: output = output.view(-1) return { "energy": output, } def _compute_loss(self, out, batch_list): energy_target = torch.cat( [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0) if self.config["dataset"].get("normalize_labels", False): target_normed = self.normalizers["target"].norm(energy_target) else: target_normed = energy_target loss = self.criterion(out["energy"], target_normed) return loss def _compute_metrics(self, out, batch_list, evaluator, metrics={}): energy_target = torch.cat( [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0) if self.config["dataset"].get("normalize_labels", False): out["energy"] = self.normalizers["target"].denorm(out["energy"]) metrics = evaluator.eval( out, {"energy": energy_target}, prev_metrics=metrics, ) return metrics
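# Hedged sketch of the mid-epoch resume idea used by EnergyTrainer.train()
# above, shown in isolation with hypothetical names (start_step, train_loader,
# train_sampler). The sampler epoch must be set before iterating so that every
# replica skips the same already-consumed indices.
start_epoch = start_step // len(train_loader)
skip_steps = start_step % len(train_loader)
train_sampler.set_epoch(start_epoch)
loader_iter = iter(train_loader)
for _ in range(skip_steps):
    next(loader_iter)  # fast-forward over batches seen before the restart
for i in range(skip_steps, len(train_loader)):
    batch = next(loader_iter)
    # forward / loss / backward go here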
def main(args): # utils.init_distributed_mode(args) if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print("number of params:", n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set="train", args=args) dataset_val = build_dataset(image_set="val", args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader( dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers, ) data_loader_val = DataLoader( dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers, ) if args.dataset_file == "coco_panoptic": # We also evaluate AP during panoptic training, on original coco DS coco_val = datasets.coco.build("val", args) base_ds = get_coco_api_from_dataset(coco_val) else: base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: checkpoint = torch.load(args.frozen_weights, map_location="cpu") model_without_ddp.detr.load_state_dict(checkpoint["model"]) if args.resume: if args.resume.startswith("https"): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location="cpu", check_hash=True) else: checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) if (not args.eval and "optimizer" in checkpoint and "lr_scheduler" in checkpoint and "epoch" in checkpoint): optimizer.load_state_dict(checkpoint["optimizer"]) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) args.start_epoch = checkpoint["epoch"] + 1 if args.eval: test_stats, coco_evaluator = evaluate( model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, ) if args.output_dir: with PathManager.open(os.path.join(args.output_dir, "eval.pth"), "wb") as f: utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, f) return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm, ) lr_scheduler.step() if args.output_dir: checkpoint_paths = [ ] # os.path.join(args.output_dir, 'checkpoint.pth')] # extra checkpoint before LR 
drop and every 10 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0: checkpoint_paths.append( os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")) for checkpoint_path in checkpoint_paths: with PathManager.open(checkpoint_path, "wb") as f: if args.gpu == 0 and args.machine_rank == 0: utils.save_on_master( { "model": model_without_ddp.state_dict(), "optimizer": optimizer.state_dict(), "lr_scheduler": lr_scheduler.state_dict(), "epoch": epoch, "args": args, }, f, ) test_stats, coco_evaluator = evaluate( model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, ) log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, **{f"test_{k}": v for k, v in test_stats.items()}, "epoch": epoch, "n_parameters": n_parameters, } if args.output_dir and utils.is_main_process(): with PathManager.open(os.path.join(args.output_dir, "log.txt"), "w") as f: f.write(json.dumps(log_stats) + "\n") # for evaluation logs if coco_evaluator is not None: PathManager.mkdirs(os.path.join(args.output_dir, "eval")) if "bbox" in coco_evaluator.coco_eval: filenames = ["latest.pth"] if epoch % 50 == 0: filenames.append(f"{epoch:03}.pth") for name in filenames: with PathManager.open( os.path.join(args.output_dir, "eval", name), "wb") as f: torch.save(coco_evaluator.coco_eval["bbox"].eval, f) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time {}".format(total_time_str))
def train(rank, args, hp, hp_str): # if hp.train.num_gpus > 1: # init_process_group(backend=hp.dist.dist_backend, init_method=hp.dist.dist_url, # world_size=hp.dist.world_size * hp.train.num_gpus, rank=rank) torch.cuda.manual_seed(hp.train.seed) device = torch.device('cuda:{:d}'.format(rank)) generator = Generator(hp.model.in_channels, hp.model.out_channels).to(device) specd = SpecDiscriminator().to(device) msd = MultiScaleDiscriminator().to(device) stft_loss = MultiResolutionSTFTLoss() if rank == 0: print(generator) os.makedirs(hp.logs.chkpt_dir, exist_ok=True) print("checkpoints directory : ", hp.logs.chkpt_dir) if os.path.isdir(hp.logs.chkpt_dir): cp_g = scan_checkpoint(hp.logs.chkpt_dir, 'g_') cp_do = scan_checkpoint(hp.logs.chkpt_dir, 'do_') steps = 0 if cp_g is None or cp_do is None: state_dict_do = None last_epoch = -1 else: state_dict_g = load_checkpoint(cp_g, device) state_dict_do = load_checkpoint(cp_do, device) generator.load_state_dict(state_dict_g['generator']) specd.load_state_dict(state_dict_do['specd']) msd.load_state_dict(state_dict_do['msd']) steps = state_dict_do['steps'] + 1 last_epoch = state_dict_do['epoch'] if hp.train.num_gpus > 1: generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) specd = DistributedDataParallel(specd, device_ids=[rank]).to(device) msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) optim_g = torch.optim.AdamW( generator.parameters(), hp.train.adamG.lr, betas=[hp.train.adamG.beta1, hp.train.adamG.beta2]) optim_d = torch.optim.AdamW( itertools.chain(msd.parameters(), specd.parameters()), hp.train.adamD.lr, betas=[hp.train.adamD.beta1, hp.train.adamD.beta2]) if state_dict_do is not None: optim_g.load_state_dict(state_dict_do['optim_g']) optim_d.load_state_dict(state_dict_do['optim_d']) # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.adam.lr_decay, last_epoch=last_epoch) # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hp.train.adam.lr_decay, last_epoch=last_epoch) training_filelist, validation_filelist = get_dataset_filelist(args) trainset = MelDataset(training_filelist, hp.data.input_wavs, hp.data.output_wavs, hp.audio.segment_length, hp.audio.filter_length, hp.audio.n_mel_channels, hp.audio.hop_length, hp.audio.win_length, hp.audio.sampling_rate, hp.audio.mel_fmin, hp.audio.mel_fmax, n_cache_reuse=0, shuffle=False if hp.train.num_gpus > 1 else True, fmax_loss=None, device=device) train_sampler = DistributedSampler( trainset) if hp.train.num_gpus > 1 else None train_loader = DataLoader(trainset, num_workers=hp.train.num_workers, shuffle=False, sampler=train_sampler, batch_size=hp.train.batch_size, pin_memory=True, drop_last=True) if rank == 0: validset = MelDataset(validation_filelist, hp.data.input_wavs, hp.data.output_wavs, hp.audio.segment_length, hp.audio.filter_length, hp.audio.n_mel_channels, hp.audio.hop_length, hp.audio.win_length, hp.audio.sampling_rate, hp.audio.mel_fmin, hp.audio.mel_fmax, split=False, shuffle=False, n_cache_reuse=0, fmax_loss=None, device=device) validation_loader = DataLoader(validset, num_workers=1, shuffle=False, sampler=None, batch_size=1, pin_memory=True, drop_last=True) sw = SummaryWriter(os.path.join(hp.logs.chkpt_dir, 'logs')) generator.train() specd.train() msd.train() with_postnet = False for epoch in range(max(0, last_epoch), args.training_epochs): if rank == 0: start = time.time() print("Epoch: {}".format(epoch + 1)) if hp.train.num_gpus > 1: train_sampler.set_epoch(epoch) for i, batch in enumerate(train_loader): if 
rank == 0: start_b = time.time() if steps > hp.train.postnet_start_steps: with_postnet = True x, y, file, _, y_mel_loss = batch x = torch.autograd.Variable(x.to(device, non_blocking=True)) y = torch.autograd.Variable(y.to(device, non_blocking=True)) y_mel_loss = torch.autograd.Variable( y_mel_loss.to(device, non_blocking=True)) # y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) x = x.unsqueeze(1) y = y.unsqueeze(1) before_y_g_hat, y_g_hat = generator(x, with_postnet) if y_g_hat is not None: y_g_hat_mel = mel_spectrogram( y_g_hat.squeeze(1), hp.audio.filter_length, hp.audio.n_mel_channels, hp.audio.sampling_rate, hp.audio.hop_length, hp.audio.win_length, hp.audio.mel_fmin, None) if steps > hp.train.discriminator_train_start_steps: for _ in range(hp.train.rep_discriminator): optim_d.zero_grad() # SpecD y_df_hat_r, y_df_hat_g, _, _ = specd( y_mel_loss, y_g_hat_mel.detach()) loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( y_df_hat_r, y_df_hat_g) # MSD y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( y_ds_hat_r, y_ds_hat_g) loss_disc_all = loss_disc_s + loss_disc_f loss_disc_all.backward() optim_d.step() before_y_g_hat_mel = mel_spectrogram( before_y_g_hat.squeeze(1), hp.audio.filter_length, hp.audio.n_mel_channels, hp.audio.sampling_rate, hp.audio.hop_length, hp.audio.win_length, hp.audio.mel_fmin, None) # Generator optim_g.zero_grad() # L1 Mel-Spectrogram Loss # before_loss_mel = F.l1_loss(y_mel_loss, before_y_g_hat_mel) sc_loss, mag_loss = stft_loss( before_y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1)) before_loss_mel = sc_loss + mag_loss # L1 Sample Loss before_loss_sample = F.l1_loss(y, before_y_g_hat) loss_gen_all = before_loss_mel + before_loss_sample if y_g_hat is not None: # L1 Mel-Spectrogram Loss # loss_mel = F.l1_loss(y_mel_loss, y_g_hat_mel) sc_loss_, mag_loss_ = stft_loss( y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1)) loss_mel = sc_loss_ + mag_loss_ # L1 Sample Loss loss_sample = F.l1_loss(y, y_g_hat) loss_gen_all += loss_mel + loss_sample if steps > hp.train.discriminator_train_start_steps: y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = specd( y_mel_loss, y_g_hat_mel) y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) loss_gen_all += hp.model.lambda_adv * ( loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f) loss_gen_all.backward() optim_g.step() if rank == 0: # STDOUT logging if steps % args.stdout_interval == 0: with torch.no_grad(): mel_error = F.l1_loss(y_mel_loss, before_y_g_hat_mel).item() sample_error = F.l1_loss(y, before_y_g_hat) print( 'Steps : {:d}, Gen Loss Total : {:4.3f}, Sample Error: {:4.3f}, ' 'Mel-Spec. 
Error : {:4.3f}, s/b : {:4.3f}'.format( steps, loss_gen_all, sample_error, mel_error, time.time() - start_b)) # checkpointing if steps % hp.logs.save_interval == 0 and steps != 0: checkpoint_path = "{}/g_{:08d}".format( hp.logs.chkpt_dir, steps) save_checkpoint( checkpoint_path, { 'generator': (generator.module if hp.train.num_gpus > 1 else generator).state_dict() }) checkpoint_path = "{}/do_{:08d}".format( hp.logs.chkpt_dir, steps) save_checkpoint( checkpoint_path, { 'specd': (specd.module if hp.train.num_gpus > 1 else specd).state_dict(), 'msd': (msd.module if hp.train.num_gpus > 1 else msd).state_dict(), 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 'epoch': epoch, 'hp_str': hp_str }) # Tensorboard summary logging if steps % hp.logs.summary_interval == 0: sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) sw.add_scalar("training/mel_spec_error", mel_error, steps) # Validation if steps % hp.logs.validation_interval == 0: # and steps != 0: generator.eval() torch.cuda.empty_cache() val_err_tot = 0 with torch.no_grad(): for j, batch in enumerate(validation_loader): x, y, file, y_mel, y_mel_loss = batch x = x.unsqueeze(1) y = y.unsqueeze(1).to(device) before_y_g_hat, y_g_hat = generator(x.to(device)) y_mel_loss = torch.autograd.Variable( y_mel_loss.to(device, non_blocking=True)) y_g_hat_mel = mel_spectrogram( before_y_g_hat.squeeze(1), hp.audio.filter_length, hp.audio.n_mel_channels, hp.audio.sampling_rate, hp.audio.hop_length, hp.audio.win_length, hp.audio.mel_fmin, None) val_err_tot += F.l1_loss(y_mel_loss, y_g_hat_mel).item() val_err_tot += F.l1_loss(y, before_y_g_hat).item() if y_g_hat is not None: val_err_tot += F.l1_loss(y, y_g_hat).item() if j <= 4: if steps == 0: sw.add_audio('gt_noise/y_{}'.format(j), x[0], steps, hp.audio.sampling_rate) sw.add_audio('gt_clean/y_{}'.format(j), y[0], steps, hp.audio.sampling_rate) sw.add_figure( 'gt/y_spec_clean_{}'.format(j), plot_spectrogram(y_mel[0]), steps) sw.add_audio('generated/y_hat_{}'.format(j), before_y_g_hat[0], steps, hp.audio.sampling_rate) if y_g_hat is not None: sw.add_audio( 'generated/y_hat_after_{}'.format(j), y_g_hat[0], steps, hp.audio.sampling_rate) y_hat_spec = mel_spectrogram( before_y_g_hat.squeeze(1), hp.audio.filter_length, hp.audio.n_mel_channels, hp.audio.sampling_rate, hp.audio.hop_length, hp.audio.win_length, hp.audio.mel_fmin, None) sw.add_figure( 'generated/y_hat_spec_{}'.format(j), plot_spectrogram( y_hat_spec.squeeze(0).cpu().numpy()), steps) val_err = val_err_tot / (j + 1) sw.add_scalar("validation/mel_spec_error", val_err, steps) generator.train() steps += 1 # scheduler_g.step() # scheduler_d.step() if rank == 0: print('Time taken for epoch {} is {} sec\n'.format( epoch + 1, int(time.time() - start)))
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) # align with DETR format args.dataset_file = 'ImageNet' args.masks = None # freeze cnn weights args.lr_backbone = 0 if args.fre_cnn else args.lr print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set='train', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.updetr_collate_fn, num_workers=args.num_workers) print(len(data_loader_train) * args.epochs) output_dir = Path(args.output_dir) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) if lr_scheduler.step_size != args.lr_drop: lr_scheduler.step_size = args.lr_drop args.start_epoch = checkpoint['epoch'] + 1 print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 20 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 20 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
def run_process(): '''Run process This is what is actually run on each process. ''' # Get distributed parameters rank = dist.get_rank() local_rank = int(os.environ["LOCAL_RANK"]) world_size = dist.get_world_size() # Initialize data_loader context_size = 512 batch_size = 32 corpus_length = 1024 vocab_size = 2**8 dataset = RandomCorpus(corpus_length, context_size, vocab_size) sampler = DistributedSampler(dataset, shuffle=True, drop_last=False) data_loader = DataLoader( dataset=dataset, batch_size=batch_size, sampler=sampler, ) # Initialize model model = GPT(vocab_size, context_size, verbose=True) device = torch.device(f"cuda:{local_rank}") model.to(device) # Prepare for distributed data parallelism model = DistributedDataParallel(model, device_ids=[rank], output_device=rank) # The learning rate is adapted for the total batch_size in tokens learning_rate = 6e-4 * (batch_size * world_size * context_size / 5e5) # ZeroRedundancyOptimizer reduces the memory footprint of the Optimizer opt = ZeroRedundancyOptimizer( model.parameters(), optimizer_class=optim.Adam, lr=learning_rate, ) loss_func = nn.CrossEntropyLoss() # Initialize logger instance to see performance writer = BenchmarkWriter() # Actual training global_step = 0 n_epochs = 10 for epoch in range(n_epochs): model.train() sampler.set_epoch(epoch) # for correct shuffling for sequence, in data_loader: opt.zero_grad() # Shift so that prediction is next token for each token sequence = sequence.to(device) logits = model(sequence[..., :-1].contiguous()) target = sequence[..., 1:].contiguous() # Flatten the tokens when calculating loss loss = loss_func( logits.flatten(end_dim=-2), target.flatten(), ) loss.backward() opt.step() # This will also log the wall time if rank == 0: global_step += batch_size * world_size writer.add_scalar("Loss", loss.item(), global_step=global_step) if rank == 0: print("Epoch:", epoch) if rank == 0: writer.benchmark_results(burn_in_steps=2 * corpus_length, step_unit="seq") writer.close() return model
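# run_process() above moves the model to cuda:{local_rank} but passes the global rank to
# DistributedDataParallel(device_ids=...). On a single node the two coincide; on multiple
# nodes device_ids normally refers to the local rank, since it indexes the GPUs of the
# current machine. A minimal sketch keeping the two consistent (assuming the process
# group is already initialized and LOCAL_RANK is set by the launcher):
import os

import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_model_for_ddp(model: torch.nn.Module) -> torch.nn.Module:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    model.to(torch.device(f"cuda:{local_rank}"))
    # device_ids / output_device use the local rank so they match the device above.
    return DistributedDataParallel(model, device_ids=[local_rank],
                                   output_device=local_rank)

# Such a per-process entry point is typically launched with, for example,
#   torchrun --nproc_per_node=4 train.py
# which spawns one process per GPU and sets RANK, LOCAL_RANK and WORLD_SIZE for each.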
def train(rank, a, h): if h.num_gpus > 1: init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'], world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank) torch.cuda.manual_seed(h.seed) device = torch.device('cuda:{:d}'.format(rank)) generator = Generator(h).to(device) mpd = MultiPeriodDiscriminator( h["discriminator_periods"] if "discriminator_periods" in h.keys() else None).to(device) msd = MultiScaleDiscriminator().to(device) if rank == 0: print(generator) os.makedirs(a.checkpoint_path, exist_ok=True) print("checkpoints directory : ", a.checkpoint_path) if os.path.isdir(a.checkpoint_path): cp_g = scan_checkpoint(a.checkpoint_path, 'g_') cp_do = scan_checkpoint(a.checkpoint_path, 'do_') steps = 0 if cp_g is not None: state_dict_g = load_checkpoint(cp_g, device) gsd = generator.state_dict() gsd.update({ k: v for k, v in state_dict_g['generator'].items() if k in gsd and state_dict_g['generator'][k].shape == gsd[k].shape }) missing_keys = { k: v for k, v in state_dict_g['generator'].items() if not (k in gsd and state_dict_g['generator'][k].shape == gsd[k].shape) }.keys() generator.load_state_dict(gsd) del gsd, state_dict_g if cp_do is None or len(missing_keys) or a.from_zero: state_dict_do = None last_epoch = -1 else: state_dict_do = load_checkpoint(cp_do, device) mpd.load_state_dict(state_dict_do['mpd']) del state_dict_do['mpd'] msd.load_state_dict(state_dict_do['msd']) del state_dict_do['msd'] steps = state_dict_do['steps'] + 1 last_epoch = state_dict_do['epoch'] if h.num_gpus > 1: generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) if state_dict_do is not None: optim_g.load_state_dict(state_dict_do['optim_g']) optim_d.load_state_dict(state_dict_do['optim_d']) del state_dict_do scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) training_filelist, validation_filelist = get_dataset_filelist( a, h.segment_size, h.sampling_rate) trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels, h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, trim_non_voiced=a.trim_non_voiced) STFT = STFT_Class(h.sampling_rate, h.num_mels, h.n_fft, h.win_size, h.hop_size, h.fmin, h.fmax) train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False, sampler=train_sampler, batch_size=h.batch_size, pin_memory=True, drop_last=True) assert len(train_loader), 'No audio files in dataset!' 
if rank == 0: validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels, h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0, fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning, trim_non_voiced=a.trim_non_voiced) validation_loader = DataLoader(validset, num_workers=h.num_workers, shuffle=False, sampler=None, batch_size=1, pin_memory=True, drop_last=True) sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'), max_queue=10000) sw.logged_gt_plots = False if h.num_gpus > 1: import gc gc.collect() torch.cuda.empty_cache() generator.train() mpd.train() msd.train() for epoch in range(max(0, last_epoch), a.training_epochs): if rank == 0: start = time.time() print("Epoch: {}".format(epoch + 1)) if h.num_gpus > 1: train_sampler.set_epoch(epoch) for i, batch in enumerate(train_loader): if rank == 0: start_b = time.time() x, y, _, y_mel = batch x = torch.autograd.Variable(x.to(device, non_blocking=True)) y = torch.autograd.Variable(y.to(device, non_blocking=True)) y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) y = y.unsqueeze(1) y_g_hat = generator(x) y_g_hat_mel = STFT.get_mel(y_g_hat.squeeze(1)) optim_d.zero_grad() # MPD y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( y_df_hat_r, y_df_hat_g) # MSD y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( y_ds_hat_r, y_ds_hat_g) loss_disc_all = loss_disc_s + loss_disc_f loss_disc_all.backward() optim_d.step() # Generator optim_g.zero_grad() # L1 Mel-Spectrogram Loss loss_mel = F.l1_loss(y_mel, y_g_hat_mel) y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel * 45 loss_gen_all.backward() optim_g.step() if rank == 0: torch.set_grad_enabled(False) # STDOUT logging if steps % a.stdout_interval == 0: print( 'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. 
Error : {:4.3f}, s/b : {:4.3f}' .format(steps, loss_gen_all, loss_mel.item(), time.time() - start_b)) # checkpointing if steps % a.checkpoint_interval == 0 and steps != 0: checkpoint_path = "{}/g_{:08d}".format( a.checkpoint_path, steps) save_checkpoint( checkpoint_path, { 'generator': (generator.module if h.num_gpus > 1 else generator).state_dict() }) checkpoint_path = "{}/do_{:08d}".format( a.checkpoint_path, steps) save_checkpoint( checkpoint_path, { 'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), 'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(), 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 'epoch': epoch }) del_old_checkpoints(a.checkpoint_path, 'g_', a.n_models_to_keep) del_old_checkpoints(a.checkpoint_path, 'do_', a.n_models_to_keep) # Tensorboard summary logging if steps % a.summary_interval == 0: sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) sw.add_scalar("training/mel_spec_error", loss_mel.item(), steps) # Validation if steps % a.validation_interval == 0: # and steps != 0: print("Validating...") n_audios_to_plot = 6 generator.eval() torch.cuda.empty_cache() val_err_tot = 0 for j, batch in tqdm(enumerate(validation_loader), total=len(validation_loader)): x, y, _, y_mel = batch y_g_hat = generator(x.to(device)) y_hat_spec = STFT.get_mel(y_g_hat.squeeze(1)) val_err_tot += F.l1_loss(y_mel, y_hat_spec.to(y_mel)).item() if j < n_audios_to_plot and not sw.logged_gt_plots: sw.add_audio(f'gt/y_{j}', y[0], steps, h.sampling_rate) sw.add_figure(f'spec_{j:02}/gt_spec', plot_spectrogram(y_mel[0]), steps) if j < n_audios_to_plot: sw.add_audio(f'generated/y_hat_{j}', y_g_hat[0], steps, h.sampling_rate) sw.add_figure( f'spec_{j:02}/pred_spec', plot_spectrogram( y_hat_spec.squeeze(0).cpu().numpy()), steps) if j > 64: # I am NOT patient enough to complete an entire validation cycle with over 1536 files. break sw.logged_gt_plots = True val_err = val_err_tot / (j + 1) sw.add_scalar("validation/mel_spec_error", val_err, steps) generator.train() print(f"Done. Val_loss = {val_err}") torch.set_grad_enabled(True) steps += 1 scheduler_g.step() scheduler_d.step() if rank == 0: print('Time taken for epoch {} is {} sec\n'.format( epoch + 1, int(time.time() - start)))
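# A skeleton of how the training function above divides responsibilities: every rank
# builds a sharded training loader from a DistributedSampler, while the validation
# loader is created only on rank 0 and is not sharded (sampler=None), because
# validation, checkpointing and TensorBoard logging all run on rank 0 alone. The
# function name and arguments below are illustrative.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def make_loaders(train_set, valid_set, rank: int, num_gpus: int, batch_size: int):
    train_sampler = DistributedSampler(train_set) if num_gpus > 1 else None
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False,
                              sampler=train_sampler, drop_last=True)
    # Only the logging process evaluates, so the validation loader is left unsharded.
    valid_loader = (DataLoader(valid_set, batch_size=1, shuffle=False, sampler=None)
                    if rank == 0 else None)
    return train_loader, train_sampler, valid_loader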
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print("number of params:", n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set="train", args=args) dataset_val = build_dataset(image_set="val", args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader( dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers, ) data_loader_val = DataLoader( dataset_val, args.batch_size if args.batch_size < 4 else 4, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers, ) if args.dataset_file == "coco_panoptic": # We also evaluate AP during panoptic training, on original coco DS coco_val = datasets.coco.build("val", args) base_ds = get_coco_api_from_dataset(coco_val) elif args.dataset_file in ["cmdd", "cmdc", "wider"]: base_ds = None elif args.dataset_file == "MOT17": base_ds = dataset_val else: base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: checkpoint = torch.load(args.frozen_weights, map_location="cpu") model_without_ddp.detr.load_state_dict(checkpoint["model"]) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith("https"): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location="cpu", check_hash=True) else: checkpoint = torch.load(args.resume, map_location="cpu") # NOTE: this is Bruno's hack to load stuff in model_dict = model_without_ddp.state_dict() pretrained_dict = checkpoint["model"] # hack for adding query stuff if ("query_embed.query_embed.weight" in model_dict.keys() and "query_embed.weight" in pretrained_dict.keys()): pretrained_dict[ "query_embed.query_embed.weight"] = pretrained_dict[ "query_embed.weight"] # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict } # if finetuning skip the linear stuff if args.finetune: pretrained_dict = { k: v for k, v in pretrained_dict.items() if k not in ["class_embed.weight", "class_embed.bias"] } # 2. overwrite entries in the existing state dict model_dict.update(pretrained_dict) # 3. 
load new state dict model_without_ddp.load_state_dict(model_dict) if (not args.eval and not args.load_model_only and "optimizer" in checkpoint and "lr_scheduler" in checkpoint and "epoch" in checkpoint): optimizer.load_state_dict(checkpoint["optimizer"]) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) args.start_epoch = checkpoint["epoch"] + 1 if args.eval: if args.test and args.dataset_file == "wider": if args.resume: s = args.resume.split("/")[:-1] output_dir = "/" + os.path.join(*s) else: output_dir = args.output_dir print("SAVING TEST WIDER TO ", output_dir) test_wider( model, criterion, postprocessors, dataset_val, data_loader_val, device, output_dir, ) return test_stats, coco_evaluator = evaluate( model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, dset_file=args.dataset_file, ) if args.output_dir and coco_evaluator is not None: utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm, ) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / "checkpoint.pth"] # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") for checkpoint_path in checkpoint_paths: utils.save_on_master( { "model": model_without_ddp.state_dict(), "optimizer": optimizer.state_dict(), "lr_scheduler": lr_scheduler.state_dict(), "epoch": epoch, "args": args, }, checkpoint_path, ) test_stats, coco_evaluator = evaluate( model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, dset_file=args.dataset_file, ) log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, **{f"test_{k}": v for k, v in test_stats.items()}, "epoch": epoch, "n_parameters": n_parameters, } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") # for evaluation logs if coco_evaluator is not None: (output_dir / "eval").mkdir(exist_ok=True) if "bbox" in coco_evaluator.coco_eval: filenames = ["latest.pth"] if epoch % 50 == 0: filenames.append(f"{epoch:03}.pth") for name in filenames: torch.save( coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name, ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time {}".format(total_time_str))
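# The resume branch above loads a checkpoint by filtering the pretrained state dict down
# to keys that exist in the current model, optionally dropping the classification head
# when finetuning, and then updating the model's own state dict. The helper below
# mirrors that step and adds a shape check for safety, so it is a sketch rather than the
# script's exact logic; the name and skip_keys argument are illustrative.
import torch


def filter_pretrained_state(model: torch.nn.Module, pretrained: dict,
                            skip_keys: tuple = ()) -> dict:
    model_dict = model.state_dict()
    # Keep entries whose key exists in the model, whose shape matches, and which are not
    # explicitly skipped (for example a class head that is being re-initialised).
    kept = {k: v for k, v in pretrained.items()
            if k in model_dict and v.shape == model_dict[k].shape and k not in skip_keys}
    model_dict.update(kept)
    return model_dict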
def train(model, optimizer, scheduler, global_step, train_dataset, dev_dataset, opt, collator, best_eval_loss): if opt.is_main: try: tb_logger = torch.utils.tensorboard.SummaryWriter( Path(opt.checkpoint_dir) / opt.name) except: tb_logger = None logger.warning('Tensorboard is not available.') train_sampler = DistributedSampler( train_dataset) if opt.is_distributed else RandomSampler(train_dataset) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=opt.per_gpu_batch_size, drop_last=True, num_workers=10, collate_fn=collator) loss, curr_loss = 0.0, 0.0 epoch = 1 model.train() while global_step < opt.total_steps: if opt.is_distributed > 1: train_sampler.set_epoch(epoch) epoch += 1 for i, batch in enumerate(train_dataloader): global_step += 1 (idx, question_ids, question_mask, passage_ids, passage_mask, gold_score) = batch _, _, _, train_loss = model( question_ids=question_ids.cuda(), question_mask=question_mask.cuda(), passage_ids=passage_ids.cuda(), passage_mask=passage_mask.cuda(), gold_score=gold_score.cuda(), ) train_loss.backward() if global_step % opt.accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip) optimizer.step() scheduler.step() model.zero_grad() train_loss = src.util.average_main(train_loss, opt) curr_loss += train_loss.item() if global_step % opt.eval_freq == 0: eval_loss, inversions, avg_topk, idx_topk = evaluate( model, dev_dataset, collator, opt) if eval_loss < best_eval_loss: best_eval_loss = eval_loss if opt.is_main: src.util.save(model, optimizer, scheduler, global_step, best_eval_loss, opt, dir_path, 'best_dev') model.train() if opt.is_main: log = f"{global_step} / {opt.total_steps}" log += f" -- train: {curr_loss/opt.eval_freq:.6f}" log += f", eval: {eval_loss:.6f}" log += f", inv: {inversions:.1f}" log += f", lr: {scheduler.get_last_lr()[0]:.6f}" for k in avg_topk: log += f" | avg top{k}: {100*avg_topk[k]:.1f}" for k in idx_topk: log += f" | idx top{k}: {idx_topk[k]:.1f}" logger.info(log) if tb_logger is not None: tb_logger.add_scalar("Evaluation", eval_loss, global_step) tb_logger.add_scalar("Training", curr_loss / (opt.eval_freq), global_step) curr_loss = 0 if opt.is_main and global_step % opt.save_freq == 0: src.util.save(model, optimizer, scheduler, global_step, best_eval_loss, opt, dir_path, f"step-{global_step}") if global_step > opt.total_steps: break
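# The inner loop above performs plain gradient accumulation: backward() runs on every
# batch, but gradients are clipped and the optimizer/scheduler step only every
# opt.accumulation_steps iterations. A compact sketch of that control flow (the names
# and the cross-entropy loss are illustrative):
import torch


def accumulation_loop(model, optimizer, scheduler, batches, accumulation_steps: int,
                      max_norm: float) -> None:
    model.zero_grad()
    for step, (inputs, targets) in enumerate(batches, start=1):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()  # gradients accumulate across iterations
        if step % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()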
class CTLTrainer(Trainer): def __init__( self, model: nn.Module, train_dataset: TSBaseDataset, valid_dataset: TSBaseDataset, test_dataset: TSBaseDataset, optimizer, evaluator: MetricEvaluator, criterion, config, ): self.config = config self._stop_training = False self.metrics = {} callbacks = [ hydra.utils.call(callback_config) for callback_config in self.config.trainer.callback.values() ] self.callbacks = CTLCallbackContainer(self, callbacks) self.world_size = self.config.device.get("world_size", 1) train_dataset = sample_data( train_dataset, self.config.dataset.get("train_samples", -1)) valid_dataset = sample_data( valid_dataset, self.config.dataset.get("valid_samples", -1)) self.valid_dataset_len = len(valid_dataset) self.train_dataset_len = len(train_dataset) self.train_sampler = None self.valid_sampler = None if self.world_size > 1: local_rank = int( self.config.device.get("local_rank", os.environ.get("LOCAL_RANK", 0))) self.device = get_device(local_rank, self.config.device.get("name", "cpu")) self.is_distributed = init_distributed( int( self.config.device.get("world_size", os.environ.get("WORLD_SIZE", 1)))) torch.cuda.synchronize() self.train_sampler = DistributedSampler(train_dataset, config.device.world_size, seed=config.trainer.get( "seed", 0), drop_last=True) self.valid_sampler = DistributedSampler(valid_dataset, config.device.world_size, seed=config.trainer.get( "seed", 0), drop_last=False) elif self.config.device.get("local_rank", None): self.device = get_device(self.config.device.get("local_rank"), self.config.device.get("name", "cpu")) else: self.device = torch.device(self.config.device.get("name", "cpu")) self.logger = setup_logger(self.config) self.optimizer = optimizer self.amp_enabled = self.config.trainer.get("AMP", False) self.model = model.to(self.device) if config.trainer.get("ema", None) is not None: self.ema = ModelEmaV2(config, model, self.device) else: self.ema = None if self.amp_enabled: self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O2", loss_scale="dynamic") if self.world_size > 1: self.model = DDP(self.model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) # TODO: line below has to go somewhere else. Or use default print. Logging module alters std streams which prevents us from # capturing their outputs. # log(config.pretty()) # XXX: Not sure about this. Maybe this should be isolated in collate_fn inside a DataLoader. Or maybe we should abstract it away in data_utils? # For sure we have to rename this. This suggests that masked target is somehow different from # regular target. 
self.train_target = "target_masked" if config.model.get( "train_target_mask", True) else "target" self.eval_target = "target_masked" if config.model.get( "eval_target_mask", True) else "target" self.test_target = "target_masked" if config.model.get( "test_target_mask", True) else "target" if self.config.dataset.get("graph", False) and self.config.model.get( "graph_eligible", False): def _collate_graph(samples, target): batch = dgl.batch(samples) labels = batch.ndata["target"] # XXX: we need discuss how to do this neatly if target == "target_masked": labels = labels[:, self.config.dataset.encoder_length:, :] return batch, labels _collate = _collate_graph else: def _collate_dict(samples, target): batch = default_collate(samples) labels = batch["target"] if target == "target_masked": labels = labels[:, self.config.dataset.encoder_length:, :] return batch, labels _collate = _collate_dict self.train_dataloader = DataLoader( train_dataset, batch_size=self.config.trainer.batch_size, num_workers=self.config.trainer.num_workers, sampler=self.train_sampler, shuffle=True if self.train_sampler is None else False, pin_memory=True, collate_fn=partial(_collate, target=self.train_target), ) self.valid_dataloader = DataLoader( valid_dataset, batch_size=self.config.trainer.batch_size, num_workers=self.config.trainer.num_workers, sampler=self.valid_sampler, shuffle=True if self.valid_sampler is None else False, pin_memory=True, collate_fn=partial(_collate, target=self.eval_target), ) self.test_dataloader = DataLoader( test_dataset, batch_size=self.config.trainer.batch_size, num_workers=1, pin_memory=True, collate_fn=partial(_collate, target=self.test_target), ) if self.config.get("scheduler", None): self.scheduler = hydra.utils.instantiate(self.config.scheduler, optimizer) else: self.scheduler = None self.evaluator = evaluator self.criterion = criterion self.log_path = self.config.get("log_path", os.getcwd()) self.global_step = 0 self.epoch = 0 self.preds_train_output_selector = config.model.get( "preds_train_output_selector", -1) self.preds_eval_output_selector = config.model.get( "preds_eval_output_selector", -1) self.preds_test_output_selector = config.model.get( "preds_test_output_selector", -1) model_ref = self.model.module if self.world_size > 1 else self.model test_method_name = config.model.get("test_method", "__call__") self.test_method = getattr(model_ref, test_method_name) checkpoint_path = config.trainer.get("checkpoint_path", None) maybe_restore_checkpoint(self, checkpoint_path) def assess_valid(self): self.model.eval() with torch.no_grad(): running_losses = 0 for i, (batch, labels) in enumerate(self.valid_dataloader): batch = to_device(batch, device=self.device) labels = to_device(labels, device=self.device) if self.ema: preds = self.ema.module(batch) else: preds = self.model(batch) if self.preds_eval_output_selector >= 0: preds = preds[..., self.preds_eval_output_selector:self. 
preds_eval_output_selector + 1] losses = self.criterion(preds, labels) losses = reduce_tensor(losses, self.world_size).detach() running_losses += losses running_losses = running_losses / (len(self.valid_dataloader.dataset) / self.config.trainer.batch_size) if len(running_losses.size()) < 1: running_losses = running_losses.unsqueeze(0) running_losses = [loss.item() for loss in running_losses] data = {"val_loss": sum(running_losses)} for i, elem in enumerate(running_losses): data["val_loss_component_" + str(i)] = elem self.logger.log(step=self.global_step, data=data, verbosity=dllogger.Verbosity.VERBOSE) self.model.train() return sum(running_losses) def train(self): self.callbacks.on_train_begin() self.global_step = 0 for epoch in range(self.epoch, self.config.trainer.num_epochs): self.callbacks.on_epoch_begin(epoch) self.logger.log(step=self.global_step, data={"epoch": epoch}, verbosity=dllogger.Verbosity.VERBOSE) for i, (batch, labels) in enumerate(self.train_dataloader): self.callbacks.on_batch_begin(i) self.optimizer.zero_grad() batch = to_device(batch, device=self.device) labels = to_device(labels, device=self.device) preds = self.model(batch) if self.preds_train_output_selector >= 0: preds = preds[..., self.preds_train_output_selector:self. preds_train_output_selector + 1] losses = self.criterion(preds, labels) loss = losses.sum() if self.amp_enabled: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() self.optimizer.step() losses = reduce_tensor(losses, self.world_size, average=True) if len(losses.size()) < 1: losses = [losses] losses = [loss.item() for loss in losses] data = {"loss": loss.item()} for k, v in enumerate(losses): data["loss_component_" + str(k)] = v self.logger.log(step=self.global_step, data=data, verbosity=dllogger.Verbosity.VERBOSE) if self.config.optimizer.get("gradient_norm", 0.0) > 0: nn.utils.clip_grad_norm( self.model.parameters(), self.config.optimizer.gradient_norm) # XXX: shouldn't we move logging to a callback? if self.global_step % self.config.trainer.log_interval == 0: self.logger.flush() self.global_step += 1 self.callbacks.on_batch_end(i, logs=data) if self.ema: self.ema.update(self.model) if self.scheduler: self.scheduler.step() self.callbacks.on_valid_begin(epoch) validation_loss = self.assess_valid() data = {"val_loss": validation_loss} self.callbacks.on_valid_end(epoch, logs=data) if is_main_process(): save_checkpoint(self, checkpoint_dir=self.log_path) if self.train_sampler: self.train_sampler.set_epoch(epoch) self.valid_sampler.set_epoch(epoch) self.callbacks.on_epoch_end(epoch, logs=data) if self._stop_training: break self.callbacks.on_train_end(logs=self.metrics) def evaluate(self): self.callbacks.on_evaluate_begin() maybe_restore_checkpoint( self, os.path.join(self.log_path, "best_checkpoint.pth.tar")) self.model.eval() with torch.no_grad(): preds_full = [] labels_full = [] weights_full = [] ids_full = [] for i, (batch, labels) in enumerate(self.test_dataloader): batch = to_device(batch, device=self.device) labels = to_device(labels, device=self.device) if self.config.evaluator.get("use_weights", False): weights = batch["weight"] else: weights = None # XXX we should abstract this away ids = batch.ndata["id"] if isinstance( batch, dgl.DGLGraph) else batch["id"] ids = ids[:, 0, ...] # Assumes that time dimension is at index 1. 
We don't check whether the example is constructed correctly labels_full.append(labels) weights_full.append(weights) preds = self.test_method(batch) if self.preds_test_output_selector >= 0: preds = preds[..., self.preds_test_output_selector:self. preds_test_output_selector + 1] ids_full.append(ids) preds_full.append(preds) preds_full = torch.cat(preds_full, dim=0).cpu().numpy() labels_full = torch.cat(labels_full, dim=0).cpu().numpy() if self.config.evaluator.get("use_weights", False): weights_full = torch.cat(weights_full).cpu().numpy() else: weights_full = np.zeros((0, 0)) ids_full = torch.cat(ids_full).cpu().numpy() eval_metrics = self.evaluator(labels_full, preds_full, weights_full, ids_full) self.metrics.update(eval_metrics) self.logger.log( step=[], data={k: float(v) for k, v in self.metrics.items()}, verbosity=dllogger.Verbosity.VERBOSE) self.callbacks.on_evaluate_end( logs=round_dict(self.metrics, decimal=3)) return round_dict(self.metrics, decimal=3)
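# In CTLTrainer.train() above, train_sampler.set_epoch(epoch) is called at the end of
# each epoch. DistributedSampler starts at epoch 0 by default, so setting the epoch only
# after iterating appears to make the first two epochs reuse the same shuffle order. The
# more common ordering calls set_epoch before iterating the loader, as sketched below
# (illustrative names, not the trainer's own API):
def run_training_epochs(train_loader, train_sampler, num_epochs: int) -> None:
    for epoch in range(num_epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # set before iterating this epoch's batches
        for batch, labels in train_loader:
            pass  # forward / backward / optimizer step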
def run(self, dataset: torch.utils.data.Dataset, memory_set: torch.utils.data.Dataset = None, query_set: torch.utils.data.Dataset = None, save_every: int = 100, **kwargs): if not self.prepared: raise RuntimeError("Training not prepared.") # DataLoader (for self-supervised pre-training) sampler = DistributedSampler(dataset) if self.distributed else None shuffle = not self.distributed data_loader = DataLoader( dataset, batch_size=self.batch_size, sampler=sampler, shuffle=shuffle, num_workers=self.num_workers, drop_last=True, pin_memory=True ) # DataLoader (for supervised evaluation) if (memory_set is not None) and (query_set is not None): memory_loader = DataLoader(memory_set, batch_size=self.batch_size*2, num_workers=self.num_workers) query_loader = DataLoader(query_set, batch_size=self.batch_size*2) knn_eval = True else: query_loader = None memory_loader = None knn_eval = False # Logging logger = kwargs.get('logger', None) for epoch in range(1, self.epochs + 1): if self.distributed and (sampler is not None): sampler.set_epoch(epoch) # Train history = self.train(data_loader) log = " | ".join([f"{k} : {v:.4f}" for k, v in history.items()]) # Evaluate if (self.local_rank == 0) and knn_eval: knn_k = kwargs.get('knn_k', [5, 200]) knn = KNNEvaluator(knn_k, num_classes=query_loader.dataset.num_classes) knn_scores = knn.evaluate(self.net_q, memory_loader=memory_loader, query_loader=query_loader) for k, score in knn_scores.items(): log += f" | knn@{k}: {score*100:.2f}%" else: knn_scores = None # Logging if logger is not None: logger.info(f"Epoch [{epoch:>4}/{self.epochs:>4}] - " + log) # TensorBoard if self.writer is not None: for k, v in history.items(): self.writer.add_scalar(k, v, global_step=epoch) if knn_scores is not None: for k, score in knn_scores.items(): self.writer.add_scalar(f'knn@{k}', score, global_step=epoch) if self.scheduler is not None: lr = self.scheduler.get_last_lr()[0] self.writer.add_scalar('lr', lr, global_step=epoch) if (epoch % save_every == 0) & (self.local_rank == 0): ckpt = os.path.join(self.ckpt_dir, f"ckpt.{epoch}.pth.tar") self.save_checkpoint(ckpt, epoch=epoch, history=history) if self.scheduler is not None: self.scheduler.step()
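# run() above computes shuffle = not self.distributed so that the DataLoader receives
# either a DistributedSampler or shuffle=True, never both; passing a sampler together
# with shuffle=True raises a ValueError. The guard reduces to the sketch below
# (illustrative function name):
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_pretraining_loader(dataset, distributed: bool, batch_size: int,
                             num_workers: int) -> DataLoader:
    sampler = DistributedSampler(dataset) if distributed else None
    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      shuffle=sampler is None,  # shuffle only when no sampler is given
                      num_workers=num_workers,
                      drop_last=True,
                      pin_memory=True)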
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(split='train', args=args) dataset_val = build_dataset(split='val', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) if args.dataset_file == "coco_panoptic": # We also evaluate AP during panoptic training, on original coco DS coco_val = datasets.coco.build("val", args) base_ds = get_coco_api_from_dataset(coco_val) elif args.dataset_file == "coco": base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: checkpoint = torch.load(args.frozen_weights, map_location='cpu') model_without_ddp.detr.load_state_dict(checkpoint['model']) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 # if args.eval: # if 'coco' in args.dataset_file: # test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, # data_loader_val, base_ds, device, args.output_dir) # if args.output_dir: # utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") # elif 'anet' == args.dataset_file: # evaluate3d(model, postprocessors, data_loader_val, device, epoch=0) # return print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, 
device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) if epoch % args.eval_freq == 0: if 'coco' in args.dataset_file: test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir) elif 'anet' == args.dataset_file: evaluate3d(model, postprocessors, data_loader_val, device, epoch)
class ForcesTrainer(BaseTrainer): """ Trainer class for the Structure to Energy & Force (S2EF) and Initial State to Relaxed State (IS2RS) tasks. .. note:: Examples of configurations for task, model, dataset and optimizer can be found in `configs/ocp_s2ef <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2re/>`_ and `configs/ocp_is2rs <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2rs/>`_. Args: task (dict): Task configuration. model (dict): Model configuration. dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. optimizer (dict): Optimizer configuration. identifier (str): Experiment identifier that is appended to log directory. run_dir (str, optional): Path to the run directory where logs are to be saved. (default: :obj:`None`) is_debug (bool, optional): Run in debug mode. (default: :obj:`False`) is_vis (bool, optional): Run in debug mode. (default: :obj:`False`) is_hpo (bool, optional): Run hyperparameter optimization with Ray Tune. (default: :obj:`False`) print_every (int, optional): Frequency of printing logs. (default: :obj:`100`) seed (int, optional): Random number seed. (default: :obj:`None`) logger (str, optional): Type of logger to be used. (default: :obj:`tensorboard`) local_rank (int, optional): Local rank of the process, only applicable for distributed training. (default: :obj:`0`) amp (bool, optional): Run using automatic mixed precision. (default: :obj:`False`) """ def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, is_hpo=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, cpu=False, ): super().__init__( task=task, model=model, dataset=dataset, optimizer=optimizer, identifier=identifier, run_dir=run_dir, is_debug=is_debug, is_vis=is_vis, is_hpo=is_hpo, print_every=print_every, seed=seed, logger=logger, local_rank=local_rank, amp=amp, cpu=cpu, name="s2ef", ) def load_task(self): print("### Loading dataset: {}".format(self.config["task"]["dataset"])) self.parallel_collater = ParallelCollater( 1 if not self.cpu else 0, self.config["model_attributes"].get("otf_graph", False), ) if self.config["task"]["dataset"] == "trajectory_lmdb": self.train_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["dataset"]) self.train_sampler = DistributedSampler( self.train_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=True, ) self.train_loader = DataLoader( self.train_dataset, batch_size=self.config["optim"]["batch_size"], collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.train_sampler, ) self.val_loader = self.test_loader = None self.val_sampler = self.test_sampler = None if "val_dataset" in self.config: self.val_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["val_dataset"]) self.val_sampler = DistributedSampler( self.val_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=False, ) self.val_loader = DataLoader( self.val_dataset, self.config["optim"].get("eval_batch_size", 64), collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.val_sampler, ) if "test_dataset" in self.config: self.test_dataset = registry.get_dataset_class( self.config["task"]["dataset"])( self.config["test_dataset"]) self.test_sampler = DistributedSampler( self.test_dataset, 
num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=False, ) self.test_loader = DataLoader( self.test_dataset, self.config["optim"].get("eval_batch_size", 64), collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.test_sampler, ) if "relax_dataset" in self.config["task"]: assert os.path.isfile(self.config["task"]["relax_dataset"]["src"]) self.relax_dataset = registry.get_dataset_class( "single_point_lmdb")(self.config["task"]["relax_dataset"]) self.relax_sampler = DistributedSampler( self.relax_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=False, ) self.relax_loader = DataLoader( self.relax_dataset, batch_size=self.config["optim"].get("eval_batch_size", 64), collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.relax_sampler, ) self.num_targets = 1 # Normalizer for the dataset. # Compute mean, std of training set labels. self.normalizers = {} if self.config["dataset"].get("normalize_labels", False): if "target_mean" in self.config["dataset"]: self.normalizers["target"] = Normalizer( mean=self.config["dataset"]["target_mean"], std=self.config["dataset"]["target_std"], device=self.device, ) else: self.normalizers["target"] = Normalizer( tensor=self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], device=self.device, ) # If we're computing gradients wrt input, set mean of normalizer to 0 -- # since it is lost when compute dy / dx -- and std to forward target std if self.config["model_attributes"].get("regress_forces", True): if self.config["dataset"].get("normalize_labels", False): if "grad_target_mean" in self.config["dataset"]: self.normalizers["grad_target"] = Normalizer( mean=self.config["dataset"]["grad_target_mean"], std=self.config["dataset"]["grad_target_std"], device=self.device, ) else: self.normalizers["grad_target"] = Normalizer( tensor=self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], device=self.device, ) self.normalizers["grad_target"].mean.fill_(0) if (self.is_vis and self.config["task"]["dataset"] != "qm9" and distutils.is_master()): # Plot label distribution. plots = [ plot_histogram( self.train_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: train", ), plot_histogram( self.val_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: val", ), plot_histogram( self.test_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: test", ), ] self.logger.log_plots(plots) # Takes in a new data source and generates predictions on it. 
@torch.no_grad() def predict(self, data_loader, per_image=True, results_file=None, disable_tqdm=True): if distutils.is_master() and not disable_tqdm: print("### Predicting on test.") assert isinstance( data_loader, ( torch.utils.data.dataloader.DataLoader, torch_geometric.data.Batch, ), ) rank = distutils.get_rank() if isinstance(data_loader, torch_geometric.data.Batch): data_loader = [[data_loader]] self.model.eval() if self.normalizers is not None and "target" in self.normalizers: self.normalizers["target"].to(self.device) self.normalizers["grad_target"].to(self.device) predictions = {"id": [], "energy": [], "forces": [], "chunk_idx": []} for i, batch_list in tqdm( enumerate(data_loader), total=len(data_loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch_list) if self.normalizers is not None and "target" in self.normalizers: out["energy"] = self.normalizers["target"].denorm( out["energy"]) out["forces"] = self.normalizers["grad_target"].denorm( out["forces"]) if per_image: systemids = [ str(i) + "_" + str(j) for i, j in zip( batch_list[0].sid.tolist(), batch_list[0].fid.tolist()) ] predictions["id"].extend(systemids) predictions["energy"].extend(out["energy"].to( torch.float16).tolist()) batch_natoms = torch.cat( [batch.natoms for batch in batch_list]) batch_fixed = torch.cat([batch.fixed for batch in batch_list]) forces = out["forces"].cpu().detach().to(torch.float16) per_image_forces = torch.split(forces, batch_natoms.tolist()) per_image_forces = [ force.numpy() for force in per_image_forces ] # evalAI only requires forces on free atoms if results_file is not None: _per_image_fixed = torch.split(batch_fixed, batch_natoms.tolist()) _per_image_free_forces = [ force[(fixed == 0).tolist()] for force, fixed in zip( per_image_forces, _per_image_fixed) ] _chunk_idx = np.array([ free_force.shape[0] for free_force in _per_image_free_forces ]) per_image_forces = _per_image_free_forces predictions["chunk_idx"].extend(_chunk_idx) predictions["forces"].extend(per_image_forces) else: predictions["energy"] = out["energy"].detach() predictions["forces"] = out["forces"].detach() return predictions predictions["forces"] = np.array(predictions["forces"]) predictions["chunk_idx"] = np.array(predictions["chunk_idx"]) predictions["energy"] = np.array(predictions["energy"]) predictions["id"] = np.array(predictions["id"]) self.save_results(predictions, results_file, keys=["energy", "forces", "chunk_idx"]) return predictions def train(self): eval_every = self.config["optim"].get("eval_every", len(self.train_loader)) primary_metric = self.config["task"].get( "primary_metric", self.evaluator.task_primary_metric[self.name]) self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0 iters = 0 self.metrics = {} start_epoch = self.start_step // len(self.train_loader) for epoch in range(start_epoch, self.config["optim"]["max_epochs"]): self.train_sampler.set_epoch(epoch) skip_steps = 0 if epoch == start_epoch and start_epoch > 0: skip_steps = start_epoch % len(self.train_loader) train_loader_iter = iter(self.train_loader) for i in range(skip_steps, len(self.train_loader)): self.model.train() current_epoch = epoch + (i + 1) / len(self.train_loader) current_step = epoch * len(self.train_loader) + (i + 1) # Get a batch. batch = next(train_loader_iter) # Forward, loss, backward. 
with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss = self.scaler.scale(loss) if self.scaler else loss self._backward(loss) scale = self.scaler.get_scale() if self.scaler else 1.0 # Compute metrics. self.metrics = self._compute_metrics( out, batch, self.evaluator, self.metrics, ) self.metrics = self.evaluator.update("loss", loss.item() / scale, self.metrics) # Log metrics. log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} log_dict.update({ "lr": self.scheduler.get_lr(), "epoch": current_epoch, "step": current_step, }) if (current_step % self.config["cmd"]["print_every"] == 0 and distutils.is_master() and not self.is_hpo): log_str = [ "{}: {:.2e}".format(k, v) for k, v in log_dict.items() ] print(", ".join(log_str)) self.metrics = {} if self.logger is not None: self.logger.log( log_dict, step=current_step, split="train", ) iters += 1 # Evaluate on val set every `eval_every` iterations. if iters % eval_every == 0: if self.val_loader is not None: val_metrics = self.validate( split="val", epoch=epoch - 1 + (i + 1) / len(self.train_loader), ) if ("mae" in primary_metric and val_metrics[primary_metric]["metric"] < self.best_val_metric) or ( val_metrics[primary_metric]["metric"] > self.best_val_metric): self.best_val_metric = val_metrics[primary_metric][ "metric"] self.save(current_epoch, current_step, val_metrics) if self.test_loader is not None: self.predict( self.test_loader, results_file="predictions", disable_tqdm=False, ) if self.is_hpo: self.hpo_update( current_epoch, current_step, self.metrics, val_metrics, ) else: self.save(current_epoch, current_step, self.metrics) if self.scheduler.scheduler_type == "ReduceLROnPlateau": if iters % eval_every == 0: self.scheduler.step( metrics=val_metrics[primary_metric]["metric"], ) else: self.scheduler.step() torch.cuda.empty_cache() self.train_dataset.close_db() if "val_dataset" in self.config: self.val_dataset.close_db() if "test_dataset" in self.config: self.test_dataset.close_db() def _forward(self, batch_list): # forward pass. if self.config["model_attributes"].get("regress_forces", True): out_energy, out_forces = self.model(batch_list) else: out_energy = self.model(batch_list) if out_energy.shape[-1] == 1: out_energy = out_energy.view(-1) out = { "energy": out_energy, } if self.config["model_attributes"].get("regress_forces", True): out["forces"] = out_forces return out def _compute_loss(self, out, batch_list): loss = [] # Energy loss. energy_target = torch.cat( [batch.y.to(self.device) for batch in batch_list], dim=0) if self.config["dataset"].get("normalize_labels", False): energy_target = self.normalizers["target"].norm(energy_target) energy_mult = self.config["optim"].get("energy_coefficient", 1) loss.append(energy_mult * self.criterion(out["energy"], energy_target)) # Force loss. 
if self.config["model_attributes"].get("regress_forces", True): force_target = torch.cat( [batch.force.to(self.device) for batch in batch_list], dim=0) if self.config["dataset"].get("normalize_labels", False): force_target = self.normalizers["grad_target"].norm( force_target) tag_specific_weights = self.config["task"].get( "tag_specific_weights", []) if tag_specific_weights != []: # handle tag specific weights as introduced in forcenet assert len(tag_specific_weights) == 3 batch_tags = torch.cat( [ batch.tags.float().to(self.device) for batch in batch_list ], dim=0, ) weight = torch.zeros_like(batch_tags) weight[batch_tags == 0] = tag_specific_weights[0] weight[batch_tags == 1] = tag_specific_weights[1] weight[batch_tags == 2] = tag_specific_weights[2] loss_force_list = torch.abs(out["forces"] - force_target) train_loss_force_unnormalized = torch.sum(loss_force_list * weight.view(-1, 1)) train_loss_force_normalizer = 3.0 * weight.sum() # add up normalizer to obtain global normalizer distutils.all_reduce(train_loss_force_normalizer) # perform loss normalization before backprop train_loss_force_normalized = train_loss_force_unnormalized * ( distutils.get_world_size() / train_loss_force_normalizer) loss.append(train_loss_force_normalized) else: # Force coefficient = 30 has been working well for us. force_mult = self.config["optim"].get("force_coefficient", 30) if self.config["task"].get("train_on_free_atoms", False): fixed = torch.cat( [batch.fixed.to(self.device) for batch in batch_list]) mask = fixed == 0 loss.append(force_mult * self.criterion( out["forces"][mask], force_target[mask])) else: loss.append(force_mult * self.criterion(out["forces"], force_target)) # Sanity check to make sure the compute graph is correct. for lc in loss: assert hasattr(lc, "grad_fn") loss = sum(loss) return loss def _compute_metrics(self, out, batch_list, evaluator, metrics={}): natoms = torch.cat( [batch.natoms.to(self.device) for batch in batch_list], dim=0) target = { "energy": torch.cat([batch.y.to(self.device) for batch in batch_list], dim=0), "forces": torch.cat([batch.force.to(self.device) for batch in batch_list], dim=0), "natoms": natoms, } out["natoms"] = natoms if self.config["task"].get("eval_on_free_atoms", True): fixed = torch.cat( [batch.fixed.to(self.device) for batch in batch_list]) mask = fixed == 0 out["forces"] = out["forces"][mask] target["forces"] = target["forces"][mask] s_idx = 0 natoms_free = [] for natoms in target["natoms"]: natoms_free.append( torch.sum(mask[s_idx:s_idx + natoms]).item()) s_idx += natoms target["natoms"] = torch.LongTensor(natoms_free).to(self.device) out["natoms"] = torch.LongTensor(natoms_free).to(self.device) if self.config["dataset"].get("normalize_labels", False): out["energy"] = self.normalizers["target"].denorm(out["energy"]) out["forces"] = self.normalizers["grad_target"].denorm( out["forces"]) metrics = evaluator.eval(out, target, prev_metrics=metrics) return metrics def run_relaxations(self, split="val", epoch=None): print("### Running ML-relaxations") self.model.eval() evaluator, metrics = Evaluator(task="is2rs"), {} if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr( self.relax_dataset[0], "y_relaxed"): split = "val" else: split = "test" ids = [] relaxed_positions = [] chunk_idx = [] for i, batch in tqdm(enumerate(self.relax_loader), total=len(self.relax_loader)): relaxed_batch = ml_relax( batch=batch, model=self, steps=self.config["task"].get("relaxation_steps", 200), fmax=self.config["task"].get("relaxation_fmax", 0.0), 
relax_opt=self.config["task"]["relax_opt"], device=self.device, transform=None, ) if self.config["task"].get("write_pos", False): systemids = [str(i) for i in relaxed_batch.sid.tolist()] natoms = relaxed_batch.natoms.tolist() positions = torch.split(relaxed_batch.pos, natoms) batch_relaxed_positions = [pos.tolist() for pos in positions] relaxed_positions += batch_relaxed_positions chunk_idx += natoms ids += systemids if split == "val": mask = relaxed_batch.fixed == 0 s_idx = 0 natoms_free = [] for natoms in relaxed_batch.natoms: natoms_free.append( torch.sum(mask[s_idx:s_idx + natoms]).item()) s_idx += natoms target = { "energy": relaxed_batch.y_relaxed, "positions": relaxed_batch.pos_relaxed[mask], "cell": relaxed_batch.cell, "pbc": torch.tensor([True, True, True]), "natoms": torch.LongTensor(natoms_free), } prediction = { "energy": relaxed_batch.y, "positions": relaxed_batch.pos[mask], "cell": relaxed_batch.cell, "pbc": torch.tensor([True, True, True]), "natoms": torch.LongTensor(natoms_free), } metrics = evaluator.eval(prediction, target, metrics) if self.config["task"].get("write_pos", False): rank = distutils.get_rank() pos_filename = os.path.join(self.config["cmd"]["results_dir"], f"relaxed_pos_{rank}.npz") np.savez_compressed( pos_filename, ids=ids, pos=np.array(relaxed_positions, dtype=object), chunk_idx=chunk_idx, ) distutils.synchronize() if distutils.is_master(): gather_results = defaultdict(list) full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz", ) for i in range(distutils.get_world_size()): rank_path = os.path.join( self.config["cmd"]["results_dir"], f"relaxed_pos_{i}.npz", ) rank_results = np.load(rank_path, allow_pickle=True) gather_results["ids"].extend(rank_results["ids"]) gather_results["pos"].extend(rank_results["pos"]) gather_results["chunk_idx"].extend( rank_results["chunk_idx"]) os.remove(rank_path) # Because of how distributed sampler works, some system ids # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] gather_results["pos"] = np.concatenate( np.array(gather_results["pos"])[idx]) gather_results["chunk_idx"] = np.cumsum( np.array(gather_results["chunk_idx"])[idx] )[:-1] # np.split does not need last idx, assumes n-1:end print(f"Writing results to {full_path}") np.savez_compressed(full_path, **gather_results) if split == "val": aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = ( aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics # Make plots. log_dict = {k: metrics[k]["metric"] for k in metrics} if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) if distutils.is_master(): print(metrics)
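# run_relaxations() above gathers per-rank result files and then deduplicates them by id
# with np.unique, because DistributedSampler (with drop_last=False) pads the sample list
# so that every rank receives the same number of items, which can repeat a few ids. The
# deduplication step in isolation (illustrative helper name):
import numpy as np


def deduplicate_by_id(ids, values):
    # np.unique returns the index of the first occurrence of each id; indexing with those
    # positions drops the padded duplicates while keeping one copy of every real sample.
    ids = np.array(ids)
    values = np.array(values, dtype=object)
    _, first_idx = np.unique(ids, return_index=True)
    return ids[first_idx], values[first_idx]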
def main(args): utils.init_distributed_mode(args) if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set='train', args=args) dataset_val = build_dataset(image_set='val', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_val = DataLoader(dataset_val, batch_size=1, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) if args.dataset_file == "coco_panoptic": # We also evaluate AP during panoptic training, on original coco DS coco_val = datasets.coco.build("val", args) base_ds = get_coco_api_from_dataset(coco_val) else: base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: io.load_frozen(args, model_without_ddp) output_dir = Path(args.output_dir) if args.resume: io.resume(args, model_without_ddp, optimizer, lr_scheduler) elif args.finetune: io.finetune(args, model_without_ddp) if args.eval: if args.output_dir and utils.is_main_process(): io.init_wandb(args.dataset_file + "-detr-eval", model, args, n_parameters=n_parameters) test_stats, evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir) if args.output_dir: io.save_on_master(evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") return print("Start training") start_time = time.time() if args.output_dir and utils.is_main_process(): io.init_wandb(args.dataset_file + "-detr", model, args, n_parameters=n_parameters) for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: io.save_checkpoint(args, model_without_ddp, optimizer, lr_scheduler, epoch) test_stats, evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir, epoch) if utils.is_main_process() and args.output_dir: 
io.log_wandb(train_stats, test_stats) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) # save final model if utils.is_main_process() and args.output_dir: io.save_on_master(model_without_ddp, output_dir / "model_final.pth") print('Training time {}'.format(total_time_str))
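# The epoch loop above calls sampler_train.set_epoch(epoch) before every epoch.
# DistributedSampler seeds its shuffle with (seed + epoch), so skipping that call
# replays the same permutation each epoch. A minimal single-process sketch
# (num_replicas=1 and rank=0 are fixed so no process group is needed):
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

order_a = list(iter(sampler))
order_b = list(iter(sampler))
assert order_a == order_b          # same order: the epoch was never advanced

sampler.set_epoch(1)
order_c = list(iter(sampler))      # reshuffled with seed + 1
print(order_a, order_c)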
def main(args): utils.init_distributed_mode(args) print(args) device = torch.device(args.device) # Data loading code print('Loading data') dataset_train = build_dataset(args.train_set, args.dataset_year, args) dataset_val = build_dataset(args.val_set, args.dataset_year, args) base_ds = get_coco_api_from_dataset(dataset_val) print('Creating data loaders') if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler( sampler_train, args.batch_size, drop_last=True, ) data_loader_train = DataLoader( dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers, ) data_loader_val = DataLoader( dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers, ) print('Creating model, always set args.return_criterion be True') args.return_criterion = True model = yolov5s(num_classes=args.num_classes) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu], ) model_without_ddp = model.module params = [p for p in model.parameters() if p.requires_grad] optimizer = torch.optim.SGD( params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, ) if args.lr_scheduler == 'cosine': lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, args.t_max) elif args.lr_scheduler == 'multi-step': lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=args.lr_steps, gamma=args.lr_gamma, ) else: raise ValueError(f'scheduler {args.lr_scheduler} not supported') output_dir = Path(args.output_dir) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: evaluate(model, data_loader_val, base_ds, device) return print('Start training') start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_one_epoch(model, optimizer, data_loader_train, device, epoch, args.print_freq) lr_scheduler.step() if args.output_dir: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'args': args, 'epoch': epoch, }, output_dir.joinpath(f'model_{epoch}.pth'), ) # evaluate after every epoch # evaluate(model, criterion, data_loader_val, device=device) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print(f'Training time {total_time_str}')
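# Two loader styles appear above: the train loader wraps its sampler in a
# BatchSampler with drop_last=True (full, equally sized batches on every rank),
# while the val loader passes batch_size/sampler directly with drop_last=False
# (no example skipped). A toy single-process sketch of that difference, with
# RandomSampler/SequentialSampler standing in for the distributed samplers:
import torch
from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                              SequentialSampler, TensorDataset)

dataset = TensorDataset(torch.arange(10).float())

sampler_train = RandomSampler(dataset)
batch_sampler_train = BatchSampler(sampler_train, batch_size=4, drop_last=True)
loader_train = DataLoader(dataset, batch_sampler=batch_sampler_train)

sampler_val = SequentialSampler(dataset)
loader_val = DataLoader(dataset, batch_size=4, sampler=sampler_val, drop_last=False)

print([b[0].shape[0] for b in loader_train])  # [4, 4]    -> 2 samples dropped
print([b[0].shape[0] for b in loader_val])    # [4, 4, 2] -> all 10 kept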
def main(args): # Init distributed mode dist.init_distributed_mode(args) # Update dataset specific configs if args.dataset_config is not None: # https://stackoverflow.com/a/16878364 d = vars(args) with open(args.dataset_config, "r") as f: cfg = json.load(f) d.update(cfg) print("git:\n {}\n".format(utils.get_sha())) # Segmentation related if args.mask_model != "none": args.masks = True if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) output_dir = Path(args.output_dir) # fix the seed for reproducibility seed = args.seed + dist.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) torch.set_deterministic(True) # Build the model model, criterion, contrastive_criterion, qa_criterion, weight_dict = build_model( args) model.to(device) assert ( criterion is not None or qa_criterion is not None ), "Error: should train either detection or question answering (or both)" # Get a copy of the model for exponential moving averaged version of the model model_ema = deepcopy(model) if args.ema else None model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print("number of params:", n_parameters) # Set up optimizers param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and "text_encoder" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "text_encoder" in n and p.requires_grad ], "lr": args.text_encoder_lr, }, ] if args.optimizer == "sgd": optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) elif args.optimizer in ["adam", "adamw"]: optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) else: raise RuntimeError(f"Unsupported optimizer {args.optimizer}") # Train dataset if len(args.combine_datasets) == 0 and not args.eval: raise RuntimeError("Please provide at least one training dataset") dataset_train, sampler_train, data_loader_train = None, None, None if not args.eval: dataset_train = ConcatDataset([ build_dataset(name, image_set="train", args=args) for name in args.combine_datasets ]) # To handle very big datasets, we chunk it into smaller parts. 
if args.epoch_chunks > 0: print( "Splitting the training set into {args.epoch_chunks} of size approximately " f" {len(dataset_train) // args.epoch_chunks}") chunks = torch.chunk(torch.arange(len(dataset_train)), args.epoch_chunks) datasets = [ torch.utils.data.Subset(dataset_train, chunk.tolist()) for chunk in chunks ] if args.distributed: samplers_train = [DistributedSampler(ds) for ds in datasets] else: samplers_train = [ torch.utils.data.RandomSampler(ds) for ds in datasets ] batch_samplers_train = [ torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) for sampler_train in samplers_train ] assert len(batch_samplers_train) == len(datasets) data_loaders_train = [ DataLoader( ds, batch_sampler=batch_sampler_train, collate_fn=partial(utils.collate_fn, False), num_workers=args.num_workers, ) for ds, batch_sampler_train in zip(datasets, batch_samplers_train) ] else: if args.distributed: sampler_train = DistributedSampler(dataset_train) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) batch_sampler_train = torch.utils.data.BatchSampler( sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader( dataset_train, batch_sampler=batch_sampler_train, collate_fn=partial(utils.collate_fn, False), num_workers=args.num_workers, ) # Val dataset if len(args.combine_datasets_val) == 0: raise RuntimeError("Please provide at leas one validation dataset") Val_all = namedtuple(typename="val_data", field_names=[ "dataset_name", "dataloader", "base_ds", "evaluator_list" ]) val_tuples = [] for dset_name in args.combine_datasets_val: dset = build_dataset(dset_name, image_set="val", args=args) sampler = (DistributedSampler(dset, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dset)) dataloader = DataLoader( dset, args.batch_size, sampler=sampler, drop_last=False, collate_fn=partial(utils.collate_fn, False), num_workers=args.num_workers, ) base_ds = get_coco_api_from_dataset(dset) val_tuples.append( Val_all(dataset_name=dset_name, dataloader=dataloader, base_ds=base_ds, evaluator_list=None)) if args.frozen_weights is not None: if args.resume.startswith("https"): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location="cpu", check_hash=True) else: checkpoint = torch.load(args.resume, map_location="cpu") if "model_ema" in checkpoint and checkpoint["model_ema"] is not None: model_without_ddp.detr.load_state_dict(checkpoint["model_ema"], strict=False) else: model_without_ddp.detr.load_state_dict(checkpoint["model"], strict=False) if args.ema: model_ema = deepcopy(model_without_ddp) # Used for loading weights from another model and starting a training from scratch. Especially useful if # loading into a model with different functionality. if args.load: print("loading from", args.load) checkpoint = torch.load(args.load, map_location="cpu") if "model_ema" in checkpoint: model_without_ddp.load_state_dict(checkpoint["model_ema"], strict=False) else: model_without_ddp.load_state_dict(checkpoint["model"], strict=False) if args.ema: model_ema = deepcopy(model_without_ddp) # Used for resuming training from the checkpoint of a model. Used when training times-out or is pre-empted. 
if args.resume: if args.resume.startswith("https"): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location="cpu", check_hash=True) else: checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) if not args.eval and "optimizer" in checkpoint and "epoch" in checkpoint: optimizer.load_state_dict(checkpoint["optimizer"]) args.start_epoch = checkpoint["epoch"] + 1 if args.ema: if "model_ema" not in checkpoint: print( "WARNING: ema model not found in checkpoint, resetting to current model" ) model_ema = deepcopy(model_without_ddp) else: model_ema.load_state_dict(checkpoint["model_ema"]) def build_evaluator_list(base_ds, dataset_name): """Helper function to build the list of evaluators for a given dataset""" evaluator_list = [] if args.no_detection: return evaluator_list iou_types = ["bbox"] if args.masks: iou_types.append("segm") evaluator_list.append( CocoEvaluator(base_ds, tuple(iou_types), useCats=False)) if "refexp" in dataset_name: evaluator_list.append(RefExpEvaluator(base_ds, ("bbox"))) if "clevrref" in dataset_name: evaluator_list.append(ClevrRefEvaluator(base_ds, ("bbox"))) if "flickr" in dataset_name: evaluator_list.append( FlickrEvaluator( args.flickr_dataset_path, subset="test" if args.test else "val", merge_boxes=args.GT_type == "merged", )) if "phrasecut" in dataset_name: evaluator_list.append( PhrasecutEvaluator( "test" if args.test else "miniv", ann_folder=args.phrasecut_orig_ann_path, output_dir=os.path.join(output_dir, "phrasecut_eval"), eval_mask=args.masks, )) return evaluator_list # Runs only evaluation, by default on the validation set unless --test is passed. if args.eval: test_stats = {} test_model = model_ema if model_ema is not None else model for i, item in enumerate(val_tuples): evaluator_list = build_evaluator_list(item.base_ds, item.dataset_name) postprocessors = build_postprocessors(args, item.dataset_name) item = item._replace(evaluator_list=evaluator_list) print(f"Evaluating {item.dataset_name}") curr_test_stats = evaluate( model=test_model, criterion=criterion, contrastive_criterion=contrastive_criterion, qa_criterion=qa_criterion, postprocessors=postprocessors, weight_dict=weight_dict, data_loader=item.dataloader, evaluator_list=item.evaluator_list, device=device, args=args, ) test_stats.update({ item.dataset_name + "_" + k: v for k, v in curr_test_stats.items() }) log_stats = { **{f"test_{k}": v for k, v in test_stats.items()}, "n_parameters": n_parameters, } print(log_stats) return # Runs training and evaluates after every --eval_skip epochs print("Start training") start_time = time.time() best_metric = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.epoch_chunks > 0: sampler_train = samplers_train[epoch % len(samplers_train)] data_loader_train = data_loaders_train[epoch % len(data_loaders_train)] print( f"Starting epoch {epoch // len(data_loaders_train)}, sub_epoch {epoch % len(data_loaders_train)}" ) else: print(f"Starting epoch {epoch}") if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch( model=model, criterion=criterion, contrastive_criterion=contrastive_criterion, qa_criterion=qa_criterion, data_loader=data_loader_train, weight_dict=weight_dict, optimizer=optimizer, device=device, epoch=epoch, args=args, max_norm=args.clip_max_norm, model_ema=model_ema, ) if args.output_dir: checkpoint_paths = [output_dir / "checkpoint.pth"] # extra checkpoint before LR drop and every 2 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 2 == 0: 
checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") for checkpoint_path in checkpoint_paths: dist.save_on_master( { "model": model_without_ddp.state_dict(), "model_ema": model_ema.state_dict() if args.ema else None, "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args, }, checkpoint_path, ) if epoch % args.eval_skip == 0: test_stats = {} test_model = model_ema if model_ema is not None else model for i, item in enumerate(val_tuples): evaluator_list = build_evaluator_list(item.base_ds, item.dataset_name) item = item._replace(evaluator_list=evaluator_list) postprocessors = build_postprocessors(args, item.dataset_name) print(f"Evaluating {item.dataset_name}") curr_test_stats = evaluate( model=test_model, criterion=criterion, contrastive_criterion=contrastive_criterion, qa_criterion=qa_criterion, postprocessors=postprocessors, weight_dict=weight_dict, data_loader=item.dataloader, evaluator_list=item.evaluator_list, device=device, args=args, ) test_stats.update({ item.dataset_name + "_" + k: v for k, v in curr_test_stats.items() }) else: test_stats = {} log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, **{f"test_{k}": v for k, v in test_stats.items()}, "epoch": epoch, "n_parameters": n_parameters, } if args.output_dir and dist.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") if epoch % args.eval_skip == 0: if args.do_qa: metric = test_stats["gqa_accuracy_answer_total_unscaled"] else: metric = np.mean([ v[1] for k, v in test_stats.items() if "coco_eval_bbox" in k ]) if args.output_dir and metric > best_metric: best_metric = metric checkpoint_paths = [output_dir / "BEST_checkpoint.pth"] # extra checkpoint before LR drop and every 100 epochs for checkpoint_path in checkpoint_paths: dist.save_on_master( { "model": model_without_ddp.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args, }, checkpoint_path, ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time {}".format(total_time_str))
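# With --epoch_chunks the loop above cycles through one sampler/loader per chunk
# (epoch % len(...)) while still passing the global epoch to set_epoch(), so a
# chunk is reshuffled every time it comes around again. A toy sketch of just the
# chunk/cycle bookkeeping (RandomSampler stands in for the DistributedSampler
# used in the distributed branch):
import torch
from torch.utils.data import RandomSampler, Subset, TensorDataset

dataset_train = TensorDataset(torch.arange(10))
epoch_chunks = 3

chunks = torch.chunk(torch.arange(len(dataset_train)), epoch_chunks)
datasets = [Subset(dataset_train, chunk.tolist()) for chunk in chunks]
samplers_train = [RandomSampler(ds) for ds in datasets]

for epoch in range(6):
    sampler_train = samplers_train[epoch % len(samplers_train)]
    # in the distributed branch: sampler_train.set_epoch(epoch)
    print(f"epoch {epoch}: chunk {epoch % len(datasets)}, {len(sampler_train)} samples")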
class Brain: r"""Brain class abstracts away the details of data loops. The primary purpose of the `Brain` class is the implementation of the ``fit()`` method, which iterates epochs and datasets for the purpose of "fitting" a set of modules to a set of data. In order to use the ``fit()`` method, one should sub-class the ``Brain`` class and override any methods for which the default behavior does not match the use case. For a simple use case (e.g., training a single model with a single dataset) the only methods that need to be overridden are: * ``compute_forward()`` * ``compute_objectives()`` The example below illustrates how overriding these two methods is done. For more complicated use cases, such as multiple modules that need to be updated, the following methods can be overridden: * ``fit_batch()`` * ``evaluate_batch()`` Arguments --------- modules : dict of str:torch.nn.Module pairs These modules are passed to the optimizer by default if they have trainable parameters, and will have ``train()``/``eval()`` called on them. opt_class : torch.optim class A torch optimizer constructor that has takes only the list of parameters (e.g. a lambda or partial function definition). By default, this will be passed all modules in ``modules`` at the beginning of the ``fit()`` method. This behavior can be changed by overriding the ``configure_optimizers()`` method. hparams : dict Each key:value pair should consist of a string key and a hyperparameter that is used within the overridden methods. These will be accessible via an ``hparams`` attribute, using "dot" notation: e.g., self.hparams.model(x). run_opts : dict A set of options to change the runtime environment, including debug (bool) If ``True``, this will only iterate a few batches for all datasets, to ensure code runs without crashing. debug_batches (int) Number of batches to run in debug mode, Default ``2``. debug_epochs (int) Number of epochs to run in debug mode, Default ``2``. If a non-positive number is passed, all epochs are run. jit_module_keys (list of str) List of keys in ``modules`` that should be jit compiled. distributed_count (int) Number of devices to run on. distributed_backend (str) One of ``ddp_nccl``, ``ddp_gloo``, ``ddp_mpi``, ``data_parallel``. device (str) The location for performing computations. auto_mix_prec (bool) If ``True``, automatic mixed-precision is used. Activate it only with cuda. max_grad_norm (float) Default implementation of ``fit_batch()`` uses ``clip_grad_norm_`` with this value. Default: ``5``. nonfinite_patience (int) Number of times to ignore non-finite losses before stopping. Default: ``3``. noprogressbar (bool) Whether to turn off progressbar when training. Default: ``False``. ckpt_interval_minutes (float) Amount of time between saving intra-epoch checkpoints, in minutes, default: ``15.0``. If non-positive, these are not saved. checkpointer : speechbrain.Checkpointer By default, this will be used to load checkpoints, and will have the optimizer added to continue training if interrupted. Example ------- >>> from torch.optim import SGD >>> class SimpleBrain(Brain): ... def compute_forward(self, batch, stage): ... return self.modules.model(batch[0]) ... def compute_objectives(self, predictions, batch, stage): ... 
return torch.nn.functional.l1_loss(predictions, batch[0]) >>> model = torch.nn.Linear(in_features=10, out_features=10) >>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1)) >>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],)) """ def __init__( # noqa: C901 self, modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None, ): self.opt_class = opt_class self.checkpointer = checkpointer # Arguments passed via the run opts dictionary run_opt_defaults = { "debug": False, "debug_batches": 2, "debug_epochs": 2, "device": "cpu", "data_parallel_count": -1, "data_parallel_backend": False, "distributed_launch": False, "distributed_backend": "nccl", "jit_module_keys": None, "auto_mix_prec": False, "max_grad_norm": 5.0, "nonfinite_patience": 3, "noprogressbar": False, "ckpt_interval_minutes": 0, } for arg, default in run_opt_defaults.items(): if run_opts is not None and arg in run_opts: if hparams is not None and arg in hparams: logger.info( "Info: " + arg + " arg overridden by command line input" ) setattr(self, arg, run_opts[arg]) else: # If any arg from run_opt_defaults exist in hparams and # not in command line args "run_opts" if hparams is not None and arg in hparams: logger.info( "Info: " + arg + " arg from hparam file is used" ) setattr(self, arg, hparams[arg]) else: setattr(self, arg, default) if self.data_parallel_backend and self.distributed_launch: sys.exit( "To use data_parallel backend, start your script with:\n\t" "python experiment.py hyperparams.yaml " "--data_parallel_backend=True --data_parallel_count=2" "To use DDP backend, start your script with:\n\t" "python -m torch.distributed.lunch [args]\n" "experiment.py hyperparams.yaml --distributed_launch=True " "--distributed_backend=nccl" ) # Switch to the right context if "cuda" in self.device: torch.cuda.set_device(int(self.device[-1])) # Put modules on the right device, accessible with dot notation self.modules = torch.nn.ModuleDict(modules).to(self.device) # Make hyperparams available with dot notation too if hparams is not None: self.hparams = SimpleNamespace(**hparams) # Checkpointer should point at a temporary directory in debug mode if ( self.debug and self.checkpointer is not None and hasattr(self.checkpointer, "checkpoints_dir") ): tempdir = tempfile.TemporaryDirectory() logger.info( "Since debug mode is active, switching checkpointer " f"output to temporary directory: {tempdir.name}" ) self.checkpointer.checkpoints_dir = pathlib.Path(tempdir.name) # Keep reference to tempdir as long as checkpointer exists self.checkpointer.tempdir = tempdir # Sampler should be handled by `make_dataloader` # or if you provide a DataLoader directly, you can set # this.train_sampler = your_sampler # to have your_sampler.set_epoch() called on each epoch. 
self.train_sampler = None # Automatic mixed precision init if self.auto_mix_prec: self.scaler = torch.cuda.amp.GradScaler() # List parameter count for the user total_params = sum( p.numel() for p in self.modules.parameters() if p.requires_grad ) if total_params > 0: clsname = self.__class__.__name__ fmt_num = sb.utils.logger.format_order_of_magnitude(total_params) logger.info(f"{fmt_num} trainable parameters in {clsname}") if self.distributed_launch: self.rank = int(os.environ["RANK"]) if not torch.distributed.is_initialized(): if self.rank > 0: sys.exit( " ================ WARNING ===============" "Please add sb.ddp_init_group() into your exp.py" "To use DDP backend, start your script with:\n\t" "python -m torch.distributed.launch [args]\n\t" "experiment.py hyperparams.yaml " "--distributed_launch=True --distributed_backend=nccl" ) else: logger.warn( "To use DDP, please add " "sb.utils.distributed.ddp_init_group() into your exp.py" ) logger.info( "Only the main process is alive, " "all other subprocess were killed." ) # force the models to start and remain synchronized torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # Prepare iterating variables self.avg_train_loss = 0.0 self.step = 0 # Add this class to the checkpointer for intra-epoch checkpoints if self.checkpointer is not None: self.checkpointer.add_recoverable("brain", self) def compute_forward(self, batch, stage): """Forward pass, to be overridden by sub-classes. Arguments --------- batch : torch.Tensor or tensors An element from the dataloader, including inputs for processing. stage : Stage The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST Returns ------- torch.Tensor or Tensors The outputs after all processing is complete. Directly passed to ``compute_objectives()``. """ raise NotImplementedError def compute_objectives(self, predictions, batch, stage): """Compute loss, to be overridden by sub-classes. Arguments --------- predictions : torch.Tensor or Tensors The output tensor or tensors to evaluate. Comes directly from ``compute_forward()``. batch : torch.Tensor or tensors An element from the dataloader, including targets for comparison. stage : Stage The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST Returns ------- loss : torch.Tensor A tensor with the computed loss. """ raise NotImplementedError def on_stage_start(self, stage, epoch=None): """Gets called when a stage starts. Useful for defining class variables used during the stage. Arguments --------- stage : Stage The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST epoch : int The current epoch count. """ pass def on_stage_end(self, stage, stage_loss, epoch=None): """Gets called at the end of a stage. Useful for computing stage statistics, saving checkpoints, etc. Arguments --------- stage : Stage The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST stage_loss : float The average loss over the completed stage. epoch : int The current epoch count. """ pass def make_dataloader( self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs, ): """Creates DataLoaders for Datasets. This is used by ``fit()`` and ``evaluate()`` if they just receive Datasets. Alternatively, this can be called from outside the Brain subclass. In that case, the DataLoader should be passed to ``fit()`` in place of the dataset. The Stage.TRAIN DataLoader is handled specially. It has extra args for shuffle and drop_last. In DDP a DistributedSampler is created (unless the dataset is an IterableDataset). 
NOTE ---- Some important DataLoader arguments are passed via **loader_kwargs, e.g., batch_size, num_workers, pin_memory. NOTE ---- By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test DataLoader being added to the checkpointer. If you need to add a recoverable after saving checkpoints (e.g., at test time, after checkpointing the training), and still be able to recover reasonably, you should probably specify ``allow_partial_load=True``. Arguments --------- dataset : Dataset A set of data to use to create data loader. If the Dataset is a DynamicItemDataset, PaddedBatch is used as the default collate_fn, unless specified in loader_kwargs. stage : Stage The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST ckpt_prefix : str, None Prefix to use for SaveableDataLoader Checkpoint name. The Stage name is added to this to create the full key. Set to None to not save the DataLoader. **loader_kwargs : dict Additional keyword arguments to the DataLoader. E.g., batch_size, num_workers, pin_memory. """ # TRAIN stage is handled specially. if stage == sb.Stage.TRAIN: loader_kwargs = self._train_loader_specifics(dataset, loader_kwargs) dataloader = sb.dataio.dataloader.make_dataloader( dataset, **loader_kwargs ) if ( self.checkpointer is not None and ckpt_prefix is not None and isinstance(dataloader, SaveableDataLoader) ): ckpt_key = ckpt_prefix + stage.name self.checkpointer.add_recoverable(ckpt_key, dataloader) return dataloader def _train_loader_specifics(self, dataset, loader_kwargs): sampler = loader_kwargs.get("sampler", None) # Shuffling should really only matter for the train stage. Shuffling # will also lead to more padding in batches if the order was otherwise # sorted by length. shuffle = loader_kwargs.get("shuffle", False) if shuffle and not self.distributed_launch: if sampler is not None: raise ValueError( "Cannot specify both shuffle=True " "and a sampler in loader_kwargs" ) sampler = ReproducibleRandomSampler(dataset) self.train_sampler = sampler loader_kwargs["sampler"] = self.train_sampler # Delete the shuffle flag, since you cannot specify both a sampler and # shuffling: del loader_kwargs["shuffle"] # Possibly make a DistributedSampler or a wrapper for some other sampler if self.distributed_launch and not isinstance(dataset, IterableDataset): drop_last = loader_kwargs.get("drop_last", False) # num_replicas arg is equal to world_size # and retrieved automatically within # DistributedSampler obj. if sampler is not None: self.train_sampler = DistributedSamplerWrapper( sampler, rank=self.rank, drop_last=drop_last, shuffle=shuffle, ) # with DistributedSamplerWrapper, one must disable shuffling for dataloader loader_kwargs["shuffle"] = False elif loader_kwargs.get("batch_sampler") is None: # Currently to get here, shuffle == False, so not passing it. # Otherwise we'd have to handle deleting it (but it is already # deleted). self.train_sampler = DistributedSampler( dataset, rank=self.rank, shuffle=shuffle, drop_last=drop_last, ) # with DistributedSamplerWrapper, one must disable shuffling for dataloader loader_kwargs["shuffle"] = False else: # batch_sampler was specified # TODO: Could a DistributedSamplerWrapper actually work # just fine for wrapping a BatchSampler, as well? logger.warning( "Cannot automatically solve distributed sampling " "when using a BatchSampler." 
) loader_kwargs["sampler"] = self.train_sampler elif self.distributed_launch and isinstance(dataset, IterableDataset): logger.warning( "Cannot automatically solve distributed sampling " "for IterableDataset." ) return loader_kwargs def on_fit_start(self): """Gets called at the beginning of ``fit()``, on multiple processes if ``distributed_count > 0`` and backend is ddp. Default implementation compiles the jit modules, initializes optimizers, and loads the latest checkpoint to resume training. """ # Run this *after* starting all processes since jit modules cannot be # pickled. self._compile_jit() # Wrap modules with parallel backend after jit self._wrap_distributed() # Initialize optimizers after parameters are configured self.init_optimizers() # Load latest checkpoint to resume training if interrupted if self.checkpointer is not None: self.checkpointer.recover_if_possible( device=torch.device(self.device) ) def init_optimizers(self): """Called during ``on_fit_start()``, initialize optimizers after parameters are fully configured (e.g. DDP, jit). The default implementation of this method depends on an optimizer class being passed at initialization that takes only a list of parameters (e.g., a lambda or a partial function definition). This creates a single optimizer that optimizes all trainable params. Override this class if there are multiple optimizers. """ if self.opt_class is not None: self.optimizer = self.opt_class(self.modules.parameters()) if self.checkpointer is not None: self.checkpointer.add_recoverable("optimizer", self.optimizer) def on_evaluate_start(self, max_key=None, min_key=None): """Gets called at the beginning of ``evaluate()`` Default implementation loads the best-performing checkpoint for evaluation, based on stored metrics. Arguments --------- max_key : str Key to use for finding best checkpoint (higher is better). By default, passed to ``self.checkpointer.recover_if_possible()``. min_key : str Key to use for finding best checkpoint (lower is better). By default, passed to ``self.checkpointer.recover_if_possible()``. """ # Recover best checkpoint for evaluation if self.checkpointer is not None: self.checkpointer.recover_if_possible( max_key=max_key, min_key=min_key, device=torch.device(self.device), ) def fit_batch(self, batch): """Fit one batch, override to do multiple updates. The default implementation depends on a few methods being defined with a particular behavior: * ``compute_forward()`` * ``compute_objectives()`` Also depends on having optimizers passed at initialization. Arguments --------- batch : list of torch.Tensors Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets. Returns ------- detached loss """ # Managing automatic mixed precision if self.auto_mix_prec: self.optimizer.zero_grad() with torch.cuda.amp.autocast(): outputs = self.compute_forward(batch, Stage.TRAIN) loss = self.compute_objectives(outputs, batch, Stage.TRAIN) self.scaler.scale(loss).backward() self.scaler.unscale_(self.optimizer) if self.check_gradients(loss): self.scaler.step(self.optimizer) self.scaler.update() else: outputs = self.compute_forward(batch, Stage.TRAIN) loss = self.compute_objectives(outputs, batch, Stage.TRAIN) loss.backward() if self.check_gradients(loss): self.optimizer.step() self.optimizer.zero_grad() return loss.detach().cpu() def check_gradients(self, loss): """Check if gradients are finite and not too large. Automatically clips large gradients. 
Arguments --------- loss : tensor The loss tensor after ``backward()`` has been called but before the optimizers ``step()``. Returns ------- bool Whether or not the optimizer step should be carried out. """ if not torch.isfinite(loss): self.nonfinite_count += 1 # Print helpful debug info logger.warn(f"Loss is {loss}.") for p in self.modules.parameters(): if not torch.isfinite(p).all(): logger.warn("Parameter is not finite: " + str(p)) # Check if patience is exhausted if self.nonfinite_count > self.nonfinite_patience: raise ValueError( "Loss is not finite and patience is exhausted. " "To debug, wrap `fit()` with " "autograd's `detect_anomaly()`, e.g.\n\nwith " "torch.autograd.detect_anomaly():\n\tbrain.fit(...)" ) else: logger.warn("Patience not yet exhausted, ignoring this batch.") return False # Clip gradient norm torch.nn.utils.clip_grad_norm_( (p for p in self.modules.parameters()), self.max_grad_norm ) return True def evaluate_batch(self, batch, stage): """Evaluate one batch, override for different procedure than train. The default implementation depends on two methods being defined with a particular behavior: * ``compute_forward()`` * ``compute_objectives()`` Arguments --------- batch : list of torch.Tensors Batch of data to use for evaluation. Default implementation assumes this batch has two elements: inputs and targets. stage : Stage The stage of the experiment: Stage.VALID, Stage.TEST Returns ------- detached loss """ out = self.compute_forward(batch, stage=stage) loss = self.compute_objectives(out, batch, stage=stage) return loss.detach().cpu() def fit( self, epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={}, ): """Iterate epochs and datasets to improve objective. Relies on the existence of multiple functions that can (or should) be overridden. The following methods are used and expected to have a certain behavior: * ``fit_batch()`` * ``evaluate_batch()`` * ``update_average()`` If the initialization was done with distributed_count > 0 and the distributed_backend is ddp, this will generally handle multiprocess logic, like splitting the training data into subsets for each device and only saving a checkpoint on the main process. Arguments --------- epoch_counter : iterable Each call should return an integer indicating the epoch count. train_set : Dataset, DataLoader A set of data to use for training. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly. valid_set : Dataset, DataLoader A set of data to use for validation. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly. train_loader_kwargs : dict Kwargs passed to `make_dataloader()` for making the train_loader (if train_set is a Dataset, not DataLoader). E.G. batch_size, num_workers. DataLoader kwargs are all valid. valid_loader_kwargs : dict Kwargs passed to `make_dataloader()` for making the valid_loader (if valid_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid. progressbar : bool Whether to display the progress of each epoch in a progressbar. 
""" if not isinstance(train_set, DataLoader): train_set = self.make_dataloader( train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs ) if valid_set is not None and not isinstance(valid_set, DataLoader): valid_set = self.make_dataloader( valid_set, stage=sb.Stage.VALID, ckpt_prefix=None, **valid_loader_kwargs, ) self.on_fit_start() if progressbar is None: progressbar = not self.noprogressbar # Iterate epochs for epoch in epoch_counter: # Training stage self.on_stage_start(Stage.TRAIN, epoch) self.modules.train() # Reset nonfinite count to 0 each epoch self.nonfinite_count = 0 if self.train_sampler is not None and hasattr( self.train_sampler, "set_epoch" ): self.train_sampler.set_epoch(epoch) # Time since last intra-epoch checkpoint last_ckpt_time = time.time() # Only show progressbar if requested and main_process enable = progressbar and sb.utils.distributed.if_main_process() with tqdm( train_set, initial=self.step, dynamic_ncols=True, disable=not enable, ) as t: for batch in t: self.step += 1 loss = self.fit_batch(batch) self.avg_train_loss = self.update_average( loss, self.avg_train_loss ) t.set_postfix(train_loss=self.avg_train_loss) # Debug mode only runs a few batches if self.debug and self.step == self.debug_batches: break if ( self.checkpointer is not None and self.ckpt_interval_minutes > 0 and time.time() - last_ckpt_time >= self.ckpt_interval_minutes * 60.0 ): run_on_main(self._save_intra_epoch_ckpt) last_ckpt_time = time.time() # Run train "on_stage_end" on all processes self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch) self.avg_train_loss = 0.0 self.step = 0 # Validation stage if valid_set is not None: self.on_stage_start(Stage.VALID, epoch) self.modules.eval() avg_valid_loss = 0.0 with torch.no_grad(): for batch in tqdm( valid_set, dynamic_ncols=True, disable=not enable ): self.step += 1 loss = self.evaluate_batch(batch, stage=Stage.VALID) avg_valid_loss = self.update_average( loss, avg_valid_loss ) # Debug mode only runs a few batches if self.debug and self.step == self.debug_batches: break # Only run validation "on_stage_end" on main process self.step = 0 run_on_main( self.on_stage_end, args=[Stage.VALID, avg_valid_loss, epoch], ) # Debug mode only runs a few epochs if self.debug and epoch == self.debug_epochs: break def _save_intra_epoch_ckpt(self): """Saves a CKPT with specific intra-epoch flag.""" self.checkpointer.save_and_keep_only( end_of_epoch=False, num_to_keep=1, ckpt_predicate=lambda c: INTRA_EPOCH_CKPT_FLAG in c.meta, meta={INTRA_EPOCH_CKPT_FLAG: True}, verbosity=logging.DEBUG, ) def _compile_jit(self): """Compile requested modules with ``torch.jit.script``.""" if self.jit_module_keys is None: return for name in self.jit_module_keys: if name not in self.modules: raise ValueError( "module" + name + " is not defined in your hparams file." 
) module = torch.jit.script(self.modules[name]) self.modules[name] = module.to(self.device) def _wrap_distributed(self): """Wrap modules with distributed wrapper when requested.""" if not self.distributed_launch and not self.data_parallel_backend: return elif self.distributed_launch: for name, module in self.modules.items(): if any(p.requires_grad for p in module.parameters()): # for ddp, all module must run on same GPU module = SyncBatchNorm.convert_sync_batchnorm(module) module = DDP(module, device_ids=[self.device]) self.modules[name] = module else: # data_parallel_backend for name, module in self.modules.items(): if any(p.requires_grad for p in module.parameters()): # if distributed_count = -1 then use all gpus # otherwise, specify the set of gpu to use if self.data_parallel_count == -1: module = DP(module) else: module = DP( module, [i for i in range(self.data_parallel_count)], ) self.modules[name] = module def evaluate( self, test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={}, ): """Iterate test_set and evaluate brain performance. By default, loads the best-performing checkpoint (as recorded using the checkpointer). Arguments --------- test_set : Dataset, DataLoader If a DataLoader is given, it is iterated directly. Otherwise passed to ``self.make_dataloader()``. max_key : str Key to use for finding best checkpoint, passed to ``on_evaluate_start()``. min_key : str Key to use for finding best checkpoint, passed to ``on_evaluate_start()``. progressbar : bool Whether to display the progress in a progressbar. test_loader_kwargs : dict Kwargs passed to ``make_dataloader()`` if ``test_set`` is not a DataLoader. NOTE: ``loader_kwargs["ckpt_prefix"]`` gets automatically overwritten to ``None`` (so that the test DataLoader is not added to the checkpointer). Returns ------- average test loss """ if progressbar is None: progressbar = not self.noprogressbar if not isinstance(test_set, DataLoader): test_loader_kwargs["ckpt_prefix"] = None test_set = self.make_dataloader( test_set, Stage.TEST, **test_loader_kwargs ) self.on_evaluate_start(max_key=max_key, min_key=min_key) self.on_stage_start(Stage.TEST, epoch=None) self.modules.eval() avg_test_loss = 0.0 with torch.no_grad(): for batch in tqdm( test_set, dynamic_ncols=True, disable=not progressbar ): self.step += 1 loss = self.evaluate_batch(batch, stage=Stage.TEST) avg_test_loss = self.update_average(loss, avg_test_loss) # Debug mode only runs a few batches if self.debug and self.step == self.debug_batches: break # Only run evaluation "on_stage_end" on main process run_on_main( self.on_stage_end, args=[Stage.TEST, avg_test_loss, None] ) self.step = 0 def update_average(self, loss, avg_loss): """Update running average of the loss. Arguments --------- loss : torch.tensor detached loss, a single float value. avg_loss : float current running average. Returns ------- avg_loss : float The average loss. """ if torch.isfinite(loss): avg_loss -= avg_loss / (self.step + 1) avg_loss += float(loss) / (self.step + 1) return avg_loss @sb.utils.checkpoints.mark_as_saver def _save(self, path): save_dict = { "step": self.step, "avg_train_loss": self.avg_train_loss, } with open(path, "w") as w: w.write(yaml.dump(save_dict)) @sb.utils.checkpoints.mark_as_loader def _recover(self, path, end_of_epoch, device): del end_of_epoch del device with open(path) as f: save_dict = yaml.safe_load(f) self.step = save_dict["step"] self.avg_train_loss = save_dict["avg_train_loss"]
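# The comment in __init__ above notes that when a DataLoader is passed to fit()
# directly, assigning its sampler to self.train_sampler makes fit() call
# set_epoch() each epoch. A minimal single-process sketch of that pattern,
# reusing the docstring's SimpleBrain and assuming this class is importable as
# speechbrain.Brain (num_replicas=1/rank=0 keep the sampler usable without a
# process group):
import torch
import speechbrain as sb
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

class SimpleBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        return self.modules.model(batch[0])

    def compute_objectives(self, predictions, batch, stage):
        return torch.nn.functional.l1_loss(predictions, batch[0])

dataset = TensorDataset(torch.rand(16, 10))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
train_loader = DataLoader(dataset, batch_size=4, sampler=sampler)

brain = SimpleBrain({"model": torch.nn.Linear(10, 10)},
                    opt_class=lambda p: torch.optim.SGD(p, 0.1))
brain.train_sampler = sampler      # fit() will call sampler.set_epoch(epoch)
brain.fit(range(2), train_loader)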
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion, postprocessors = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad ] }, { "params": [ p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad ], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set='train', args=args) dataset_val = build_dataset(image_set='val', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) if args.dataset_file == "coco_panoptic": # We also evaluate AP during panoptic training, on original coco DS coco_val = datasets.coco.build("val", args) base_ds = get_coco_api_from_dataset(coco_val) else: base_ds = get_coco_api_from_dataset(dataset_val) if args.frozen_weights is not None: checkpoint = torch.load(args.frozen_weights, map_location='cpu') model_without_ddp.detr.load_state_dict(checkpoint['model']) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model'], strict=False) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.eval: test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir) if args.output_dir: utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") return #cab writer = SummaryWriter("runs/" + args.tb_name) best_value = 0 print("Start training, best_value is " + str(best_value)) start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() 
test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir) #cab for k, v in train_stats.items(): if isinstance(v, float): writer.add_scalar(f'train_{k}', v, epoch) new_value = 0 for k, v in test_stats.items(): if (isinstance(v, float)): writer.add_scalar(f'test_{k}', v, epoch) if (k == "coco_eval_bbox"): new_value = v[0] writer.add_scalar( 'Bbox Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', v[0], epoch) writer.add_scalar( 'Bbox Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ]', v[1], epoch) writer.add_scalar( 'Bbox Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ]', v[2], epoch) writer.add_scalar( 'Bbox Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', v[3], epoch) writer.add_scalar( 'Bbox Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', v[4], epoch) writer.add_scalar( 'Bbox Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', v[5], epoch) writer.add_scalar( 'Bbox Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]', v[6], epoch) writer.add_scalar( 'Bbox Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]', v[7], epoch) writer.add_scalar( 'Bbox Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', v[8], epoch) writer.add_scalar( 'Bbox Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', v[9], epoch) writer.add_scalar( 'Bbox Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', v[10], epoch) writer.add_scalar( 'Bbox Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', v[11], epoch) if (k == "coco_eval_masks"): new_value = v[0] writer.add_scalar( 'Mask Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', v[0], epoch) writer.add_scalar( 'Mask Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ]', v[1], epoch) writer.add_scalar( 'Mask Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ]', v[2], epoch) writer.add_scalar( 'Mask Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', v[3], epoch) writer.add_scalar( 'Mask Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', v[4], epoch) writer.add_scalar( 'Mask Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', v[5], epoch) writer.add_scalar( 'Mask Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ]', v[6], epoch) writer.add_scalar( 'Mask Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ]', v[7], epoch) writer.add_scalar( 'Mask Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]', v[8], epoch) writer.add_scalar( 'Mask Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]', v[9], epoch) writer.add_scalar( 'Mask Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]', v[10], epoch) writer.add_scalar( 'Mask Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]', v[11], epoch) print("Epoch finished, best_value is " + str(best_value)) save_pth = False if best_value < new_value: best_value = new_value save_pth = True if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 100 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') if save_pth: checkpoint_paths.append(output_dir / f'best.pth') bestLog = open(output_dir / 'best_log.txt', 'w+') bestLog.write(f'Saved model at epoch {epoch:04}\n') 
for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) #/cab log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") # for evaluation logs if coco_evaluator is not None: (output_dir / 'eval').mkdir(exist_ok=True) if "bbox" in coco_evaluator.coco_eval: filenames = ['latest.pth'] if epoch % 50 == 0: filenames.append(f'{epoch:03}.pth') for name in filenames: torch.save(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
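# The add_scalar block above writes the twelve COCO bbox/mask statistics one
# call at a time. The same values can be logged in a loop over the stat names
# (order follows pycocotools' COCOeval.stats vector); the helper name and the
# shorter tag strings here are illustrative, not the original tags:
COCO_STAT_NAMES = [
    'AP @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]',
    'AP @[ IoU=0.50      | area=   all | maxDets=100 ]',
    'AP @[ IoU=0.75      | area=   all | maxDets=100 ]',
    'AP @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
    'AP @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
    'AP @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
    'AR @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ]',
    'AR @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ]',
    'AR @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]',
    'AR @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
    'AR @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
    'AR @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
]

def log_coco_stats(writer, prefix, stats, epoch):
    """Log one scalar per COCO statistic, e.g. prefix='Bbox' or 'Mask'."""
    for name, value in zip(COCO_STAT_NAMES, stats):
        writer.add_scalar(f'{prefix} {name}', value, epoch)

# Inside the loop above this would replace the per-index calls, e.g.:
#   if k == "coco_eval_bbox":  log_coco_stats(writer, 'Bbox', v, epoch)
#   if k == "coco_eval_masks": log_coco_stats(writer, 'Mask', v, epoch)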
def main(args): utils.init_distributed_mode(args) print("git:\n {}\n".format(utils.get_sha())) print(args) device = torch.device(args.device) # Fix the seed for reproducibility. seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) model, criterion = build_model(args) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) param_dicts = [ {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]}, { "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad], "lr": args.lr_backbone, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) dataset_train = build_dataset(image_set='train', args=args) if args.distributed: sampler_train = DistributedSampler(dataset_train) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers) # Load from pretrained DETR model. assert args.num_queries == 100, args.num_queries assert args.enc_layers == 6 and args.dec_layers == 6 assert args.backbone in ['resnet50', 'resnet101', 'swin'], args.backbone if args.backbone == 'resnet50': pretrain_model = './data/detr_coco/detr-r50-e632da11.pth' elif args.backbone == 'resnet101': pretrain_model = './data/detr_coco/detr-r101-2c7b67e5.pth' else: pretrain_model = None if pretrain_model is not None: pretrain_dict = torch.load(pretrain_model, map_location='cpu')['model'] my_model_dict = model_without_ddp.state_dict() pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in my_model_dict} my_model_dict.update(pretrain_dict) model_without_ddp.load_state_dict(my_model_dict) output_dir = Path(args.output_dir) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 print("Start training") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] # extra checkpoint before LR drop and every 10 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') if (epoch + 1) > args.lr_drop and (epoch + 1) % 10 == 0: checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') for checkpoint_path in checkpoint_paths: utils.save_on_master({ 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) log_stats = 
{**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
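# The pretrained-DETR loading above keeps only keys that also exist in the new
# model (`if k in my_model_dict`). When num_classes or num_queries differ, a key
# can match while its tensor shape does not, so a slightly stricter filter also
# compares shapes before updating. A toy sketch of that pattern (the two
# Sequential models and their mismatched heads are only illustrative):
import torch

pretrained = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 91))
target = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 5))

pretrain_dict = pretrained.state_dict()
my_model_dict = target.state_dict()

filtered = {
    k: v for k, v in pretrain_dict.items()
    if k in my_model_dict and v.shape == my_model_dict[k].shape
}
my_model_dict.update(filtered)
target.load_state_dict(my_model_dict)

print(sorted(filtered))  # only the first layer's weight/bias survive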
def main(args): prt.init_distributed_mode(args) device = torch.device(args.device) model = SlotModel(args) print( "train model: " + f"{'use slot ' if args.use_slot else 'without slot '}" + f"{'negetive loss' if args.use_slot and args.loss_status != 1 else 'positive loss'}" ) model.to(device) model_without_ddp = model if args.thop: def freeze_layers(model): for layer in model.children(): if isinstance(layer, torch.nn.Sequential): for sub_layer in layer: sub_layer.requires_grad = False for parameter in sub_layer.parameters(): parameter.requires_grad = False else: layer.requires_grad = False for parameter in layer.parameters(): parameter.requires_grad = False def unfreeze_layers(model): for layer in model.children(): if isinstance(layer, torch.nn.Sequential): for sub_layer in layer: sub_layer.requires_grad = True for parameter in sub_layer.parameters(): parameter.requires_grad = True else: layer.requires_grad = True for parameter in layer.parameters(): parameter.requires_grad = True unfreeze_layers(model) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print(float(n_parameters) / 1000000, 'M') freeze_layers(model) model.cpu() model.eval() tl.set_backend('pytorch') input_ = torch.randn(1, 3, 260, 260) flops_list = [] params_list = [] acc_list = [] flops, params = profile(model, inputs=(input_, )) flops_list.append(flops) params_list.append(params) flops, params = clever_format([flops, params], "%.3f") print(float(n_parameters) / 1000000, 'M', params, flops) return [float(n_parameters) / 1000000, flops_list[-1] / 1000000000] if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) params = [p for p in model_without_ddp.parameters() if p.requires_grad] optimizer = torch.optim.AdamW(params, lr=args.lr) criterion = torch.nn.CrossEntropyLoss() lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_drop) dataset_train, dataset_val = select_dataset(args) if args.distributed: sampler_train = DistributedSampler(dataset_train) sampler_val = DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoaderX(dataset_train, batch_sampler=batch_sampler_train, num_workers=args.num_workers) data_loader_val = DataLoaderX(dataset_val, args.batch_size, sampler=sampler_val, num_workers=args.num_workers) output_dir = Path(args.output_dir) if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 print("Start training") start_time = time.time() log = MetricLog() record = log.record for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) train_one_epoch(model, data_loader_train, optimizer, device, record, epoch) lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / (f"{args.dataset}_" + f"{'use_slot_' if args.use_slot else 'no_slot_'}"\ + 
f"{'negative_' if args.use_slot and args.loss_status != 1 else ''}"\ + f"{'for_area_size_'+str(args.lambda_value) + '_'+ str(args.slots_per_class) + '_' if args.cal_area_size else ''}" + 'checkpoint.pth')] # extra checkpoint before LR drop and every 10 epochs if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0: checkpoint_paths.append(output_dir / (f"{args.dataset}_" + f"{'use_slot_' if args.use_slot else 'no_slot_'}"\ + f"{'negative_' if args.use_slot and args.loss_status != 1 else ''}"\ + f"{'for_area_size_'+str(args.lambda_value) + '_'+ str(args.slots_per_class) + '_' if args.cal_area_size else ''}" + f'checkpoint{epoch:04}.pth')) for checkpoint_path in checkpoint_paths: prt.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args, }, checkpoint_path) evaluate(model, data_loader_val, device, record, epoch) log.print_metric() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) return [record["train"]["acc"][-1], record["val"]["acc"][-1]]
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # separate learning rate for the backbone parameters
    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop, gamma=0.9)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size,
                                 sampler=sampler_val, drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    output_dir = output_dir / f"{args.backbone}_{args.transformer_type}"
    if args.output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device,
                                              args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval,
                                 output_dir / "eval.pth")
        # evaluation-only run: stop here instead of falling through to training
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / f'checkpoint_{epoch}.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch}_extra.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device,
                                              args.output_dir)

        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
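# Sketch: roughly what the utils.init_distributed_mode(args) call at the top of
# main() is assumed to do. The helper itself is not shown in this document, so
# the exact attribute names (args.rank, args.world_size, args.gpu,
# args.distributed) are assumptions modeled on the common torchrun /
# torch.distributed.launch convention: read the env vars the launcher sets,
# pin the local GPU, and join the NCCL process group.
import os
import torch
import torch.distributed as dist


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        # single-process fallback: DistributedSampler and DDP wrapping are skipped
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()

# Typical launch (script name is illustrative): torchrun --nproc_per_node=4 train.py <args>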
def main(rank):
    global args, best_prec1
    args = parser.parse_args()
    device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
    print("====rank={} device_id={} ".format(rank, device_id))
    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type, 'segment%d' % args.num_segments,
        'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    torch.cuda.set_device(device_id)
    check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    # model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device_id],
                                                      output_device=device_id)

    optimizer = torch.optim.SGD(policies, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))
        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_dataset = TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    # validation runs on rank 0 only
    if rank == 0:
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))

    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    # GradScaler's first positional argument is init_scale; the AMP flag belongs to `enabled`
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, scaler, args.batch_size, args.amp, rank)
        # train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

        # evaluate on validation set
        if rank == 0 and ((epoch + 1) % args.eval_freq == 0
                          or epoch == args.epochs - 1):
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
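# Sketch: how the GradScaler created above would typically be consumed inside
# train() (that function is not shown here, so this inner step is an assumption,
# and train_step/images/target are illustrative names). autocast runs the forward
# pass in mixed precision when amp is enabled, and the scaler rescales the loss so
# fp16 gradients do not underflow before the optimizer step.
import torch


def train_step(model, criterion, optimizer, scaler, images, target, amp):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=amp):
        output = model(images)
        loss = criterion(output, target)
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads, skips the step on inf/nan
    scaler.update()                 # adjust the scale factor for the next iteration
    return loss.detach()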