Example #1
    def _get_sampler(train_set, test_set, val_set, train_sampler, test_sampler,
                     val_sampler, start_epoch):
        if train_sampler is None:
            if is_distributed():
                train_sampler = DistributedSampler(
                    train_set,
                    num_replicas=get_world_size(),
                    rank=get_global_rank())
                train_sampler.set_epoch(start_epoch)
            else:
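                # non-distributed fallback: random sampling with replacement (replacement=True)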
                train_sampler = RandomSampler(train_set, True)
        else:
            train_sampler = train_sampler(train_set)

        if test_sampler is None:
            if is_distributed():
                test_sampler = DistributedSampler(
                    test_set,
                    num_replicas=get_world_size(),
                    rank=get_global_rank())
        else:
            test_sampler = test_sampler(test_set)

        if val_set is not None:
            if val_sampler is None and is_distributed():
                val_sampler = DistributedSampler(val_set,
                                                 num_replicas=get_world_size(),
                                                 rank=get_global_rank())
                val_sampler.set_epoch(start_epoch)
            elif val_sampler is not None:
                val_sampler = val_sampler(val_set)

        return train_sampler, test_sampler, val_sampler
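The samplers returned above plug straight into DataLoaders. A minimal sketch, assuming the helper is callable as a plain function and that train_set, test_set, and num_epochs come from the surrounding training script:

# Hypothetical wiring of _get_sampler (treated here as a standalone helper);
# train_set, test_set and num_epochs are placeholders.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

train_sampler, test_sampler, _ = _get_sampler(
    train_set, test_set, None,
    train_sampler=None, test_sampler=None, val_sampler=None,
    start_epoch=0)

train_loader = DataLoader(train_set, batch_size=32, sampler=train_sampler)
test_loader = DataLoader(test_set, batch_size=32, sampler=test_sampler)

for epoch in range(num_epochs):
    if isinstance(train_sampler, DistributedSampler):
        train_sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
    for batch in train_loader:
        ...  # training step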
Example #2
def get_ddp_sampler(dataset: Dataset, epoch: int):
    """
    Create a DistributedSampler (with its epoch set) if DDP is initialized;
    otherwise return None.
    """
    if is_initialized():
        sampler = DistributedSampler(dataset)
        sampler.set_epoch(epoch)
    else:
        sampler = None
    return sampler
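A minimal usage sketch, assuming my_dataset and num_epochs come from the surrounding code: because the helper bakes the epoch into the sampler at creation time, this sketch recreates the sampler and loader each epoch, and only enables shuffling when no sampler is returned.

# Hypothetical usage of get_ddp_sampler; my_dataset and num_epochs are placeholders.
from torch.utils.data import DataLoader

for epoch in range(num_epochs):
    sampler = get_ddp_sampler(my_dataset, epoch)
    loader = DataLoader(my_dataset,
                        batch_size=32,
                        sampler=sampler,
                        shuffle=(sampler is None))
    for batch in loader:
        ...  # training step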
Example #3
def build_data_loader(
    image_path: Union[str, Path],
    config: dict,
    uses_absolute_paths: bool,
    shuffle_off: bool = False,
    dataset_class: Type[AutoencoderDataset] = AutoencoderDataset
) -> DataLoader:
    transform_list = [
        transforms.Resize((config['image_size'], config['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ) * config['input_dim'],
                             (0.5, ) * config['input_dim'])
    ]
    transform_list = transforms.Compose(transform_list)

    dataset = dataset_class(
        image_path,
        root=os.path.dirname(image_path) if not uses_absolute_paths else None,
        transforms=transform_list,
        loader=resilient_loader,
    )

    sampler = None
    if get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=not shuffle_off)
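        # NOTE: set_epoch is normally called with the current epoch at the start of
        # each epoch so every rank draws the same permutation; seeding it with the
        # rank gives each process a different shuffle order.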
        sampler.set_epoch(get_rank())

    if shuffle_off:
        shuffle = False
    else:
        shuffle = sampler is None

    loader = DataLoader(
        dataset,
        config['batch_size'],
        shuffle=shuffle,
        drop_last=True,
        sampler=sampler,
    )
    return loader
Example #4
class BalancedBatchSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas,
        rank,
        device,
        mode="atoms",
        shuffle=True,
        drop_last=False,
        force_balancing=False,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.device = device
        self.mode = mode.lower()
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.balance_batches = self.num_replicas > 1
        if self.balance_batches:
            if (
                not hasattr(dataset, "metadata_path")
                or not dataset.metadata_path.is_file()
            ):
                if force_balancing:
                    logging.warning(
                        f"No metadata file found at '{dataset.metadata_path}'. "
                        "BalancedBatchSampler has to load the data to "
                        "determine batch sizes, which incurs "
                        "significant overhead!"
                    )
                    self.sizes = None
                else:
                    logging.warning(
                        f"No metadata file found at '{dataset.metadata_path}'. "
                        "Batches will not be balanced, "
                        "which can incur significant overhead!"
                    )
                    self.balance_batches = False
                    self.sizes = None
            else:
                if self.mode == "atoms":
                    self.sizes = np.load(dataset.metadata_path)["natoms"]
                elif self.mode == "neighbors":
                    self.sizes = np.load(dataset.metadata_path)["neighbors"]
                else:
                    raise NotImplementedError(
                        f"Unknown load balancing mode: {self.mode}"
                    )
        else:
            self.sizes = None

        self.single_sampler = DistributedSampler(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        self.batch_sampler = BatchSampler(
            self.single_sampler,
            batch_size,
            drop_last=drop_last,
        )

    def __len__(self):
        return len(self.batch_sampler)

    def set_epoch(self, epoch):
        self.single_sampler.set_epoch(epoch)

    def __iter__(self):
        for batch_idx in self.batch_sampler:
            if self.balance_batches:
                if self.sizes is None:
                    # Unfortunately, we need to load the data to know the sample sizes
                    data_list = [self.dataset[idx] for idx in batch_idx]

                    if self.mode == "atoms":
                        sizes = [data.num_nodes for data in data_list]
                    elif self.mode == "neighbors":
                        sizes = [
                            data.edge_index.shape[1] for data in data_list
                        ]
                    else:
                        raise NotImplementedError(
                            f"Unknown load balancing mode: {self.mode}"
                        )
                else:
                    sizes = [self.sizes[idx] for idx in batch_idx]

                idx_sizes = torch.stack(
                    [torch.tensor(batch_idx), torch.tensor(sizes)]
                )
                idx_sizes_all = distutils.all_gather(
                    idx_sizes, device=self.device
                )
                idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
                idx_all = idx_sizes_all[0]
                sizes_all = idx_sizes_all[1]

                local_idx_balanced = balanced_partition(
                    sizes_all.numpy(), num_parts=self.num_replicas
                )
                # Since DistributedSampler pads the last batch
                # this should always have an entry for each replica.
                yield idx_all[local_idx_balanced[self.rank]]
            else:
                yield batch_idx
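A wiring sketch for the class above, assuming dataset, device, collate_fn, num_epochs, and the distutils helpers come from the same codebase: the object is passed to the DataLoader via batch_sampler, and set_epoch is called once per epoch so every replica shuffles identically.

# Hypothetical wiring of BalancedBatchSampler; dataset, device, collate_fn,
# num_epochs and the distutils helpers are assumed from the surrounding codebase.
from torch.utils.data import DataLoader

sampler = BalancedBatchSampler(
    dataset,
    batch_size=16,
    num_replicas=distutils.get_world_size(),
    rank=distutils.get_rank(),
    device=device,
    mode="atoms",
    shuffle=True,
)
loader = DataLoader(dataset, batch_sampler=sampler,
                    collate_fn=collate_fn, num_workers=4)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # keep the shuffle consistent across replicas
    for batch in loader:
        ...  # training step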
Example #5
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # IPython.embed()
    # IPython.embed()
    # os.system("sudo chmod -R 777 /home/shuxuang/.cache/")
    model, criterion, postprocessors = build_model(
        args)  # use the same model as detr paper on coco
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    # dataset_train = build_dataset(image_set='train', args=args)
    # dataset_val = build_dataset(image_set='val', args=args)
    # modify the dataset from coco to nvdata
    # home_dir = os.environ["HOME"]
    # dataset_train_ = build_nvdataset(dataset_root=[
    #                                     os.path.join(os.environ["HOME"],'datasets/annotation_sql_nvidia'),
    #                                     os.path.join(os.environ["HOME"], 'datasets/frames_nvidia')],
    #                                 mode='train')
    # dataset_val = build_nvdataset(dataset_root=[
    #                                 os.path.join(os.environ["HOME"],'datasets/test'),
    #                                 os.path.join(os.environ["HOME"], 'datasets/frames_nvidia')],
    #                               mode='test')
    # indices_50k =np.load(os.path.join(os.environ["HOME"],'datasets/id_1_criterion_Max_SSD_num_labels_50000.npy'))

    dataset_train = build_nvdataset(
        dataset_root=[args.dataset_root_sql, args.dataset_root_img],
        mode='train',
        camera=args.camera)
    dataset_val = build_nvdataset(
        dataset_root=[args.dataset_root_test, args.dataset_root_test],
        mode='test',
        camera=args.camera)
    if args.root_indices is not None:
        indices_50k = np.load(os.path.join(args.root_indices))
        # indices_50k =np.load(os.path.join(os.environ["HOME"],'datasets/id_1_criterion_Max_SSD_num_labels_50000.npy'))
        dataset_train = Subset(dataset_train, indices_50k)
    # IPython.embed()
    print("Train samples: %d" % (len(dataset_train)))

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    # if args.dataset_file == "coco_panoptic":
    #     # We also evaluate AP during panoptic training, on original coco DS
    #     coco_val = datasets.coco.build("val", args)
    #     base_ds = get_coco_api_from_dataset(coco_val)
    # elif args.dataset_file == "nvdata":
    #     coco_val = datasets.coco.build("val", args)
    #     base_ds = get_coco_api_from_dataset(coco_val)
    # else:
    #     base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    # if args.eval:
    #     test_stats, coco_evaluator = evaluate_nvdata(model, criterion, postprocessors,
    #                                           data_loader_val, base_ds, device, args.output_dir)
    #     if args.output_dir:
    #         utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
    #     return

    # if args.eval:
    #     evaluate(model, dataset_val, postprocessors, device)

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 50 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 50 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        # test_stats, coco_evaluator = evaluate_nvdata(
        #     model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
        # )

        # log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
        #              **{f'test_{k}': v for k, v in test_stats.items()},
        #              'epoch': epoch,
        #              'n_parameters': n_parameters}

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            # if coco_evaluator is not None:
            #     (output_dir / 'eval').mkdir(exist_ok=True)
            #     if "bbox" in coco_evaluator.coco_eval:
            #         filenames = ['latest.pth']
            #         if epoch % 50 == 0:
            #             filenames.append(f'{epoch:03}.pth')
            #         for name in filenames:
            #             torch.save(coco_evaluator.coco_eval["bbox"].eval,
            #                        output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #6
def train(rank, a, h):
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'],
                           init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus,
                           rank=rank)

    torch.cuda.manual_seed(h.seed)
    torch.cuda.set_device(rank)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # avoid a NameError when the checkpoint directory does not exist yet
    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(),
                                                mpd.parameters()),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist,
                          h.segment_size,
                          h.n_fft,
                          h.num_mels,
                          h.hop_size,
                          h.win_size,
                          h.sampling_rate,
                          h.fmin,
                          h.fmax,
                          n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True,
                          fmax_loss=h.fmax_for_loss,
                          device=device,
                          fine_tuning=a.fine_tuning,
                          base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
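    # the DistributedSampler (when present) shards and shuffles the data per rank,
    # so the DataLoader below keeps shuffle=False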

    train_loader = DataLoader(trainset,
                              num_workers=h.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist,
                              h.segment_size,
                              h.n_fft,
                              h.num_mels,
                              h.hop_size,
                              h.win_size,
                              h.sampling_rate,
                              h.fmin,
                              h.fmax,
                              False,
                              False,
                              n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss,
                              device=device,
                              fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset,
                                       num_workers=1,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device,
                                                     non_blocking=True))
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft,
                                          h.num_mels, h.sampling_rate,
                                          h.hop_size, h.win_size, h.fmin,
                                          h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, loss_gen_all, mel_error,
                                time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator': (generator.module if h.num_gpus > 1
                                          else generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'mpd': (mpd.module
                                    if h.num_gpus > 1 else mpd).state_dict(),
                            'msd': (msd.module
                                    if h.num_gpus > 1 else msd).state_dict(),
                            'optim_g':
                            optim_g.state_dict(),
                            'optim_d':
                            optim_d.state_dict(),
                            'steps':
                            steps,
                            'epoch':
                            epoch
                        })

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = torch.autograd.Variable(
                                y_mel.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(
                                y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                h.sampling_rate, h.hop_size, h.win_size,
                                h.fmin, h.fmax_for_loss)
                            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0],
                                                 steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j),
                                                  plot_spectrogram(x[0]),
                                                  steps)

                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_g_hat[0], steps,
                                             h.sampling_rate)
                                y_hat_spec = mel_spectrogram(
                                    y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                    h.sampling_rate, h.hop_size, h.win_size,
                                    h.fmin, h.fmax)
                                sw.add_figure(
                                    'generated/y_hat_spec_{}'.format(j),
                                    plot_spectrogram(
                                        y_hat_spec.squeeze(0).cpu().numpy()),
                                    steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/mel_spec_error", val_err,
                                      steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
Example #7
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    wandb.init(project="qpic-project",
               entity="sangbaeklee",
               group="experiment_qpic")
    wandb.config = {
        "learning_rate": args.lr,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
    }

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)
    wandb.watch(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if not args.hoi:
        if args.dataset_file == "coco_panoptic":
            # We also evaluate AP during panoptic training, on original coco DS
            coco_val = datasets.coco.build("val", args)
            base_ds = get_coco_api_from_dataset(coco_val)
        else:
            base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
    elif args.pretrained:
        checkpoint = torch.load(args.pretrained, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'], strict=False)

    if args.eval:
        if args.hoi:
            test_stats = evaluate_hoi(args.dataset_file, model, postprocessors,
                                      data_loader_val,
                                      args.subject_category_id, device)
            return
        else:
            test_stats, coco_evaluator = evaluate(model, criterion,
                                                  postprocessors,
                                                  data_loader_val, base_ds,
                                                  device, args.output_dir)
            if args.output_dir:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval,
                                     output_dir / "eval.pth")
            return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        if args.hoi:
            test_stats = evaluate_hoi(args.dataset_file, model, postprocessors,
                                      data_loader_val,
                                      args.subject_category_id, device)
            coco_evaluator = None
        else:
            test_stats, coco_evaluator = evaluate(model, criterion,
                                                  postprocessors,
                                                  data_loader_val, base_ds,
                                                  device, args.output_dir)

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }
        #import pdb; pdb.set_trace()
        if args.dataset_file == 'hico':
            wandb.log({
                "loss": train_stats['loss'],
                "mAP": test_stats['mAP'],
                "mAP rare": test_stats['mAP rare'],
                "mAP non-rare": test_stats['mAP non-rare'],
                "mean max recall": test_stats['mean max recall']
            })
        elif args.dataset_file == 'vcoco':
            wandb.log({
                "mAP_all": test_stats['mAP_all'],
                "mAP_thesis": test_stats['mAP_thesis'],
                "AP_hold_obj": test_stats['AP_hold_obj'],
                "AP_stand": test_stats['AP_stand'],
                "AP_sit_instr": test_stats['AP_sit_instr'],
                "AP_ride_instr": test_stats['AP_ride_instr'],
                "AP_walk": test_stats['AP_walk'],
                "AP_look_obj": test_stats['AP_look_obj'],
                "AP_hit_instr": test_stats['AP_hit_instr'],
                "AP_hit_obj": test_stats['AP_hit_obj'],
                "AP_eat_obj": test_stats['AP_eat_obj'],
                "AP_eat_instr": test_stats['AP_eat_instr'],
                "AP_jump_instr": test_stats['AP_jump_instr'],
                "AP_lay_instr": test_stats['AP_lay_instr'],
                "AP_talk_on_phone_instr": test_stats['AP_talk_on_phone_instr'],
                "AP_carry_obj": test_stats['AP_carry_obj'],
                "AP_throw_obj": test_stats['AP_throw_obj'],
                "AP_catch_obj": test_stats['AP_catch_obj'],
                "AP_cut_instr": test_stats['AP_cut_instr'],
                "AP_cut_obj": test_stats['AP_cut_obj'],
                "AP_run": test_stats['AP_run'],
                "AP_work_on_computer_instr": test_stats['AP_work_on_computer_instr'],
                "AP_ski_instr": test_stats['AP_ski_instr'],
                "AP_surf_instr": test_stats['AP_surf_instr'],
                "AP_skateboard_instr": test_stats['AP_skateboard_instr'],
                "AP_smile": test_stats['AP_smile'],
                "AP_drink_instr": test_stats['AP_drink_instr'],
                "AP_kick_obj": test_stats['AP_kick_obj'],
                "AP_point_instr": test_stats['AP_point_instr'],
                "AP_read_obj": test_stats['AP_read_obj'],
                "AP_snowboard_instr": test_stats['AP_snowboard_instr'],\
                "loss" : train_stats['loss']
            })
        else:
            continue

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #8
def main(args):
    bz = args.batch_size
    lr = args.lr

    if args.cuda:
        if torch.cuda.device_count() >= 1:
            utils.init_distributed_mode(args)
        device = torch.device(args.device)
    else:
        device = torch.device('cpu')

    # fix the seed for reproducibility
    if args.cuda:
        seed = args.seed + utils.get_rank()
    else:
        seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # set up model
    model, criterion, postprocessors = build_model(args)

    model_without_ddp = model
    if args.cuda and args.distributed:
        if args.mp:
            model = torch.nn.parallel.DistributedDataParallel(model)
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model.to(args.gpu),
                device_ids=[args.gpu],
                find_unused_parameters=True)

        model_without_ddp = model.module
    elif args.cuda:
        model = model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # set up model training
    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "joiner" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "joiner" in n and p.requires_grad
            ],
            "lr":
            args.lr_joiner,
        },
    ]

    # datasets build
    dataset_train = build_dataset(mode="training", args=args)
    dataset_test = build_dataset(mode="testing", args=args)

    if args.cuda and args.distributed:
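        # note: shuffle=False fixes the sample order across epochs, so the
        # set_epoch() call in the training loop does not change the ordering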
        sampler_train = DistributedSampler(dataset_train, shuffle=False)
        sampler_test = DistributedSampler(dataset_test, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_test = torch.utils.data.SequentialSampler(dataset_test)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_test = DataLoader(dataset_test,
                                  1,
                                  sampler=sampler_test,
                                  drop_last=False,
                                  collate_fn=utils.collate_fn,
                                  num_workers=args.num_workers)

    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    # output and checkpoints directory
    checkpoint_dir = args.output_dir
    if not os.path.exists(checkpoint_dir):
        try:
            os.makedirs(checkpoint_dir)
        except OSError:
            pass

    if args.resume:
        checkpoint = Path(args.resume)
        assert checkpoint.exists()

        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    print("Start Training")
    start_time = time.time()
    optimizer.zero_grad()
    for epoch in range(args.start_epoch, args.epochs):
        if args.cuda and args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(epoch, args.clip_max_norm, model,
                                      criterion, data_loader_train, optimizer,
                                      lr_scheduler, device)

        if args.output_dir:
            checkpoint_dir = Path(checkpoint_dir)
            checkpoint_paths = [checkpoint_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every args.save_checkpoint_every epochs
            if (epoch + 1) % args.lr_drop == 0 or (
                    epoch + 1) % args.save_checkpoint_every == 0:
                checkpoint_paths.append(checkpoint_dir /
                                        f'checkpoint{epoch:05}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        if (epoch + 1) % args.eval_interval == 0:
            # evaluation
            test_stats = evaluate(epoch, model, criterion, postprocessors,
                                  data_loader_test, args.output_dir,
                                  args.dataset, device)

            log_stats = {
                **{'train_' + str(k): v
                   for k, v in train_stats.items()},
                **{'test_' + str(k): v
                   for k, v in test_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }

            if args.output_dir and utils.is_main_process():
                with (checkpoint_dir / 'log.json').open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        lr_scheduler.step()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #9
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    # no validation ground truth for ytvos dataset
    dataset_train = build_dataset(image_set='train', args=args)
    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)

    output_dir = Path(args.output_dir)

    # load coco pretrained weight
    checkpoint = torch.load(args.pretrained_weights,
                            map_location='cpu')['model']
    del checkpoint["vistr.class_embed.weight"]
    del checkpoint["vistr.class_embed.bias"]
    del checkpoint["vistr.query_embed.weight"]
    # load into the unwrapped model so this also works without DDP
    model_without_ddp.load_state_dict(checkpoint, strict=False)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every epoch
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #10
class EnergyTrainer(BaseTrainer):
    """
    Trainer class for the Initial Structure to Relaxed Energy (IS2RE) task.

    .. note::

        Examples of configurations for task, model, dataset and optimizer
        can be found in `configs/ocp_is2re <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2re/>`_.


    Args:
        task (dict): Task configuration.
        model (dict): Model configuration.
        dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset.
        optimizer (dict): Optimizer configuration.
        identifier (str): Experiment identifier that is appended to log directory.
        run_dir (str, optional): Path to the run directory where logs are to be saved.
            (default: :obj:`None`)
        is_debug (bool, optional): Run in debug mode.
            (default: :obj:`False`)
        is_vis (bool, optional): Run in visualization mode.
            (default: :obj:`False`)
        print_every (int, optional): Frequency of printing logs.
            (default: :obj:`100`)
        seed (int, optional): Random number seed.
            (default: :obj:`None`)
        logger (str, optional): Type of logger to be used.
            (default: :obj:`tensorboard`)
        local_rank (int, optional): Local rank of the process, only applicable for distributed training.
            (default: :obj:`0`)
        amp (bool, optional): Run using automatic mixed precision.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        cpu=False,
    ):
        super().__init__(
            task=task,
            model=model,
            dataset=dataset,
            optimizer=optimizer,
            identifier=identifier,
            run_dir=run_dir,
            is_debug=is_debug,
            is_vis=is_vis,
            print_every=print_every,
            seed=seed,
            logger=logger,
            local_rank=local_rank,
            amp=amp,
            cpu=cpu,
            name="is2re",
        )

    def load_task(self):
        assert (self.config["task"]["dataset"] == "single_point_lmdb"
                ), "EnergyTrainer requires single_point_lmdb dataset"

        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            1 if not self.cpu else 0,
            self.config["model_attributes"].get("otf_graph", False),
        )

        self.train_dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])

        self.train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=distutils.get_world_size(),
            rank=distutils.get_rank(),
            shuffle=True,
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["optim"]["batch_size"],
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
            sampler=self.train_sampler,
        )

        self.val_loader = self.test_loader = None
        self.val_sampler = None

        if "val_dataset" in self.config:
            self.val_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["val_dataset"])
            self.val_sampler = DistributedSampler(
                self.val_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.val_sampler,
            )
        if "test_dataset" in self.config:
            self.test_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["test_dataset"])
            self.test_sampler = DistributedSampler(
                self.test_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.test_loader = DataLoader(
                self.test_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.test_sampler,
            )

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", False):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                raise NotImplementedError

    def predict(self, loader, results_file=None, disable_tqdm=False):
        if distutils.is_master() and not disable_tqdm:
            print("### Predicting on test.")
        assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
        rank = distutils.get_rank()

        self.model.eval()
        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
        predictions = {"id": [], "energy": []}

        for i, batch in tqdm(
                enumerate(loader),
                total=len(loader),
                position=rank,
                desc="device {}".format(rank),
                disable=disable_tqdm,
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"])
            predictions["id"].extend([str(i) for i in batch[0].sid.tolist()])
            predictions["energy"].extend(out["energy"].tolist())

        self.save_results(predictions, results_file, keys=["energy"])
        return predictions

    def train(self):
        self.best_val_mae = 1e9

        start_epoch = self.start_step // len(self.train_loader)
        for epoch in range(start_epoch, self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            self.model.train()

            skip_steps = 0
            if epoch == start_epoch and start_epoch > 0:
                # resume mid-epoch: skip the steps already completed in this epoch
                skip_steps = self.start_step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                batch = next(train_loader_iter)
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)})
                if (i % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master()):
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                if self.update_lr_on_step:
                    self.scheduler.step()

            if not self.update_lr_on_step:
                self.scheduler.step()

            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (val_metrics[self.evaluator.task_primary_metric[self.name]]
                    ["metric"] < self.best_val_mae):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric[
                            self.name]]["metric"]
                    current_step = (epoch + 1) * len(self.train_loader)
                    self.save(epoch + 1, current_step, val_metrics)
                    if self.test_loader is not None:
                        self.predict(
                            self.test_loader,
                            results_file="predictions",
                            disable_tqdm=False,
                        )
            else:
                current_step = (epoch + 1) * len(self.train_loader)
                self.save(epoch + 1, current_step, self.metrics)

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()

    def _forward(self, batch_list):
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            "energy": output,
        }

    def _compute_loss(self, out, batch_list):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0)

        if self.config["dataset"].get("normalize_labels", False):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss = self.criterion(out["energy"], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics=None):
        # Avoid a shared mutable default argument; start a fresh dict per call.
        if metrics is None:
            metrics = {}
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0)

        if self.config["dataset"].get("normalize_labels", False):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        metrics = evaluator.eval(
            out,
            {"energy": energy_target},
            prev_metrics=metrics,
        )

        return metrics
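
The resume logic above hinges on two things: DistributedSampler.set_epoch(epoch) reseeds the shuffle so every rank draws the same permutation for that epoch, and start_step % len(train_loader) batches are fast-forwarded in the first resumed epoch. Below is a minimal, self-contained sketch of that pattern (not part of the trainer above); it spins up a single-process gloo group only so a DistributedSampler can be constructed, and the skipped batches are simply discarded.

import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

dataset = TensorDataset(torch.arange(16).float())
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

start_step = 6  # pretend we resume from global step 6
start_epoch = start_step // len(loader)
for epoch in range(start_epoch, 3):
    sampler.set_epoch(epoch)  # same permutation on every rank for this epoch
    skip_steps = start_step % len(loader) if epoch == start_epoch else 0
    loader_iter = iter(loader)
    for _ in range(skip_steps):  # fast-forward batches already trained on
        next(loader_iter)
    for (batch, ) in loader_iter:
        pass  # forward/backward/step would go here

dist.destroy_process_group()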
Example #11
def main(args):
    # utils.init_distributed_mode(args)

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print("number of params:", n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(
        dataset_train,
        batch_sampler=batch_sampler_train,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model_without_ddp.detr.load_state_dict(checkpoint["model"])

    if args.resume:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location="cpu",
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if (not args.eval and "optimizer" in checkpoint
                and "lr_scheduler" in checkpoint and "epoch" in checkpoint):
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1

    if args.eval:
        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
        )
        if args.output_dir:
            with PathManager.open(os.path.join(args.output_dir, "eval.pth"),
                                  "wb") as f:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, f)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
        )
        lr_scheduler.step()
        if args.output_dir:
            # checkpoint_paths = [os.path.join(args.output_dir, 'checkpoint.pth')]
            checkpoint_paths = []
            # extra checkpoint before LR drop and every 10 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0:
                checkpoint_paths.append(
                    os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth"))
            for checkpoint_path in checkpoint_paths:
                with PathManager.open(checkpoint_path, "wb") as f:
                    if args.gpu == 0 and args.machine_rank == 0:
                        utils.save_on_master(
                            {
                                "model": model_without_ddp.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "lr_scheduler": lr_scheduler.state_dict(),
                                "epoch": epoch,
                                "args": args,
                            },
                            f,
                        )

        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
        )

        log_stats = {
            **{f"train_{k}": v
               for k, v in train_stats.items()},
            **{f"test_{k}": v
               for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            # Append one JSON line per epoch instead of overwriting the log.
            with PathManager.open(os.path.join(args.output_dir, "log.txt"),
                                  "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                PathManager.mkdirs(os.path.join(args.output_dir, "eval"))
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        with PathManager.open(
                                os.path.join(args.output_dir, "eval", name),
                                "wb") as f:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       f)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
Example #12
def train(rank, args, hp, hp_str):
    # if hp.train.num_gpus > 1:
    #     init_process_group(backend=hp.dist.dist_backend, init_method=hp.dist.dist_url,
    #                        world_size=hp.dist.world_size * hp.train.num_gpus, rank=rank)

    torch.cuda.manual_seed(hp.train.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(hp.model.in_channels,
                          hp.model.out_channels).to(device)
    specd = SpecDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)
    stft_loss = MultiResolutionSTFTLoss()

    if rank == 0:
        print(generator)
        os.makedirs(hp.logs.chkpt_dir, exist_ok=True)
        print("checkpoints directory : ", hp.logs.chkpt_dir)

    # Initialize so the checks below don't raise NameError on a fresh run.
    cp_g = cp_do = None
    if os.path.isdir(hp.logs.chkpt_dir):
        cp_g = scan_checkpoint(hp.logs.chkpt_dir, 'g_')
        cp_do = scan_checkpoint(hp.logs.chkpt_dir, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        specd.load_state_dict(state_dict_do['specd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if hp.train.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        specd = DistributedDataParallel(specd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(
        generator.parameters(),
        hp.train.adamG.lr,
        betas=[hp.train.adamG.beta1, hp.train.adamG.beta2])
    optim_d = torch.optim.AdamW(
        itertools.chain(msd.parameters(), specd.parameters()),
        hp.train.adamD.lr,
        betas=[hp.train.adamD.beta1, hp.train.adamD.beta2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.adam.lr_decay, last_epoch=last_epoch)
    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hp.train.adam.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(args)

    trainset = MelDataset(training_filelist,
                          hp.data.input_wavs,
                          hp.data.output_wavs,
                          hp.audio.segment_length,
                          hp.audio.filter_length,
                          hp.audio.n_mel_channels,
                          hp.audio.hop_length,
                          hp.audio.win_length,
                          hp.audio.sampling_rate,
                          hp.audio.mel_fmin,
                          hp.audio.mel_fmax,
                          n_cache_reuse=0,
                          shuffle=False if hp.train.num_gpus > 1 else True,
                          fmax_loss=None,
                          device=device)

    train_sampler = DistributedSampler(
        trainset) if hp.train.num_gpus > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=hp.train.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=hp.train.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist,
                              hp.data.input_wavs,
                              hp.data.output_wavs,
                              hp.audio.segment_length,
                              hp.audio.filter_length,
                              hp.audio.n_mel_channels,
                              hp.audio.hop_length,
                              hp.audio.win_length,
                              hp.audio.sampling_rate,
                              hp.audio.mel_fmin,
                              hp.audio.mel_fmax,
                              split=False,
                              shuffle=False,
                              n_cache_reuse=0,
                              fmax_loss=None,
                              device=device)
        validation_loader = DataLoader(validset,
                                       num_workers=1,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(hp.logs.chkpt_dir, 'logs'))

    generator.train()
    specd.train()
    msd.train()
    with_postnet = False
    for epoch in range(max(0, last_epoch), args.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if hp.train.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            if steps > hp.train.postnet_start_steps:
                with_postnet = True
            x, y, file, _, y_mel_loss = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel_loss = torch.autograd.Variable(
                y_mel_loss.to(device, non_blocking=True))
            # y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            x = x.unsqueeze(1)
            y = y.unsqueeze(1)
            before_y_g_hat, y_g_hat = generator(x, with_postnet)

            if y_g_hat is not None:
                y_g_hat_mel = mel_spectrogram(
                    y_g_hat.squeeze(1), hp.audio.filter_length,
                    hp.audio.n_mel_channels, hp.audio.sampling_rate,
                    hp.audio.hop_length, hp.audio.win_length,
                    hp.audio.mel_fmin, None)

            if steps > hp.train.discriminator_train_start_steps:
                for _ in range(hp.train.rep_discriminator):
                    optim_d.zero_grad()

                    # SpecD
                    y_df_hat_r, y_df_hat_g, _, _ = specd(
                        y_mel_loss, y_g_hat_mel.detach())
                    loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                        y_df_hat_r, y_df_hat_g)

                    # MSD
                    y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
                    loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                        y_ds_hat_r, y_ds_hat_g)

                    loss_disc_all = loss_disc_s + loss_disc_f

                    loss_disc_all.backward()
                    optim_d.step()

            before_y_g_hat_mel = mel_spectrogram(
                before_y_g_hat.squeeze(1), hp.audio.filter_length,
                hp.audio.n_mel_channels, hp.audio.sampling_rate,
                hp.audio.hop_length, hp.audio.win_length, hp.audio.mel_fmin,
                None)
            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            # before_loss_mel = F.l1_loss(y_mel_loss, before_y_g_hat_mel)
            sc_loss, mag_loss = stft_loss(
                before_y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
            before_loss_mel = sc_loss + mag_loss

            # L1 Sample Loss
            before_loss_sample = F.l1_loss(y, before_y_g_hat)
            loss_gen_all = before_loss_mel + before_loss_sample

            if y_g_hat is not None:
                # L1 Mel-Spectrogram Loss
                # loss_mel = F.l1_loss(y_mel_loss, y_g_hat_mel)
                sc_loss_, mag_loss_ = stft_loss(
                    y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
                loss_mel = sc_loss_ + mag_loss_
                # L1 Sample Loss
                loss_sample = F.l1_loss(y, y_g_hat)
                loss_gen_all += loss_mel + loss_sample

            if steps > hp.train.discriminator_train_start_steps:
                y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = specd(
                    y_mel_loss, y_g_hat_mel)
                y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
                loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
                loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
                loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
                loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
                loss_gen_all += hp.model.lambda_adv * (
                    loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f)

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % args.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel_loss,
                                              before_y_g_hat_mel).item()
                        sample_error = F.l1_loss(y, before_y_g_hat).item()

                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Sample Error: {:4.3f}, '
                        'Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.format(
                            steps, loss_gen_all, sample_error, mel_error,
                            time.time() - start_b))

                # checkpointing
                if steps % hp.logs.save_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        hp.logs.chkpt_dir, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator':
                            (generator.module if hp.train.num_gpus > 1 else
                             generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        hp.logs.chkpt_dir, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'specd': (specd.module if hp.train.num_gpus > 1
                                      else specd).state_dict(),
                            'msd': (msd.module if hp.train.num_gpus > 1 else
                                    msd).state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch,
                            'hp_str': hp_str,
                        })

                # Tensorboard summary logging
                if steps % hp.logs.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % hp.logs.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, file, y_mel, y_mel_loss = batch
                            x = x.unsqueeze(1)
                            y = y.unsqueeze(1).to(device)
                            before_y_g_hat, y_g_hat = generator(x.to(device))
                            y_mel_loss = torch.autograd.Variable(
                                y_mel_loss.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(
                                before_y_g_hat.squeeze(1),
                                hp.audio.filter_length,
                                hp.audio.n_mel_channels,
                                hp.audio.sampling_rate, hp.audio.hop_length,
                                hp.audio.win_length, hp.audio.mel_fmin, None)
                            val_err_tot += F.l1_loss(y_mel_loss,
                                                     y_g_hat_mel).item()
                            val_err_tot += F.l1_loss(y, before_y_g_hat).item()
                            if y_g_hat is not None:
                                val_err_tot += F.l1_loss(y, y_g_hat).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt_noise/y_{}'.format(j),
                                                 x[0], steps,
                                                 hp.audio.sampling_rate)
                                    sw.add_audio('gt_clean/y_{}'.format(j),
                                                 y[0], steps,
                                                 hp.audio.sampling_rate)
                                    sw.add_figure(
                                        'gt/y_spec_clean_{}'.format(j),
                                        plot_spectrogram(y_mel[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             before_y_g_hat[0], steps,
                                             hp.audio.sampling_rate)
                                if y_g_hat is not None:
                                    sw.add_audio(
                                        'generated/y_hat_after_{}'.format(j),
                                        y_g_hat[0], steps,
                                        hp.audio.sampling_rate)
                                y_hat_spec = mel_spectrogram(
                                    before_y_g_hat.squeeze(1),
                                    hp.audio.filter_length,
                                    hp.audio.n_mel_channels,
                                    hp.audio.sampling_rate,
                                    hp.audio.hop_length, hp.audio.win_length,
                                    hp.audio.mel_fmin, None)
                                sw.add_figure(
                                    'generated/y_hat_spec_{}'.format(j),
                                    plot_spectrogram(
                                        y_hat_spec.squeeze(0).cpu().numpy()),
                                    steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/mel_spec_error", val_err,
                                      steps)

                    generator.train()

            steps += 1

        # scheduler_g.step()
        # scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
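
This per-rank train(rank, args, hp, hp_str) function expects one process per GPU, with the (commented-out) init_process_group call made before anything is wrapped in DistributedDataParallel. A self-contained sketch of that launch pattern follows, with a placeholder worker standing in for the real train(); the port number and the gloo fallback for CPU-only machines are assumptions.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo_worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    # Build models, wrap them in DistributedDataParallel, create the
    # DistributedSampler-backed DataLoader and run the training loop here.
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = max(1, torch.cuda.device_count())
    mp.spawn(demo_worker, args=(world_size, ), nprocs=world_size)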
Example #13
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    # align with DETR format
    args.dataset_file = 'ImageNet'
    args.masks = None
    # freeze cnn weights
    args.lr_backbone = 0 if args.fre_cnn else args.lr
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.updetr_collate_fn,
                                   num_workers=args.num_workers)

    print(len(data_loader_train) * args.epochs)

    output_dir = Path(args.output_dir)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            if lr_scheduler.step_size != args.lr_drop:
                lr_scheduler.step_size = args.lr_drop
            args.start_epoch = checkpoint['epoch'] + 1

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 20 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 20 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
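
One resume detail worth isolating: after lr_scheduler.load_state_dict, a StepLR keeps the step_size stored in the checkpoint, which is why the example overwrites it when --lr_drop changed between runs. A small self-contained sketch of that behaviour (the concrete values are made up):

import torch
from torch import nn, optim

model = nn.Linear(2, 2)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Scheduler as configured for this run (step_size=40) ...
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40)
# ... and a checkpoint written by an earlier run that used step_size=30.
checkpoint = {
    "lr_scheduler": optim.lr_scheduler.StepLR(optimizer, step_size=30).state_dict()
}

lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
print(lr_scheduler.step_size)  # 30 -- taken from the checkpoint

new_lr_drop = 40
if lr_scheduler.step_size != new_lr_drop:
    lr_scheduler.step_size = new_lr_drop  # re-apply the current setting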
Example #14
def run_process():
    '''Run process

    This is what is actually run on each process.
    '''
    # Get distributed parameters
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()

    # Initialize data_loader
    context_size = 512
    batch_size = 32
    corpus_length = 1024
    vocab_size = 2**8

    dataset = RandomCorpus(corpus_length, context_size, vocab_size)
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
    )

    # Initialize model
    model = GPT(vocab_size, context_size, verbose=True)

    device = torch.device(f"cuda:{local_rank}")
    model.to(device)

    # Prepare for distributed data parallelism
    # The model lives on cuda:{local_rank}, so DDP must be given the local rank,
    # not the global rank (the two differ in multi-node runs).
    model = DistributedDataParallel(model,
                                    device_ids=[local_rank],
                                    output_device=local_rank)

    # The learning rate is adapted for the total batch_size in tokens
    learning_rate = 6e-4 * (batch_size * world_size * context_size / 5e5)
    # ZeroRedundancyOptimizer reduces the memory footprint of the Optimizer
    opt = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=optim.Adam,
        lr=learning_rate,
    )
    loss_func = nn.CrossEntropyLoss()

    # Initialize logger instance to see performance
    writer = BenchmarkWriter()

    # Actual training
    global_step = 0
    n_epochs = 10
    for epoch in range(n_epochs):
        model.train()
        sampler.set_epoch(epoch)  # for correct shuffling
        for sequence, in data_loader:
            opt.zero_grad()

            # Shift so that prediction is next token for each token
            sequence = sequence.to(device)
            logits = model(sequence[..., :-1].contiguous())
            target = sequence[..., 1:].contiguous()

            # Flatten the tokens when calculating loss
            loss = loss_func(
                logits.flatten(end_dim=-2),
                target.flatten(),
            )
            loss.backward()
            opt.step()

            # This will also log the wall time
            if rank == 0:
                global_step += batch_size * world_size
                writer.add_scalar("Loss", loss.item(), global_step=global_step)

        if rank == 0:
            print("Epoch:", epoch)

    if rank == 0:
        writer.benchmark_results(burn_in_steps=2 * corpus_length,
                                 step_unit="seq")
    writer.close()

    return model
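
run_process() assumes the default process group is already initialized (it calls dist.get_rank() and reads LOCAL_RANK directly), so it is meant to be wrapped by a torchrun entry point. A hedged sketch of such an entry point, assuming run_process is defined or imported in the same module and the script is launched as `torchrun --nproc_per_node=<num_gpus> train_gpt.py` (the script name is made up):

import torch.distributed as dist


def main():
    # torchrun provides RANK/WORLD_SIZE/LOCAL_RANK via the environment,
    # so env:// initialization needs no extra arguments.
    dist.init_process_group(backend="nccl")
    try:
        run_process()  # the function defined above
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()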
Example #15
def train(rank, a, h):
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'],
                           init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus,
                           rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator(
        h["discriminator_periods"] if "discriminator_periods" in
        h.keys() else None).to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Initialize so the checks below don't raise NameError on a fresh run.
    cp_g = cp_do = None
    missing_keys = []
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is not None:
        state_dict_g = load_checkpoint(cp_g, device)
        gsd = generator.state_dict()
        gsd.update({
            k: v
            for k, v in state_dict_g['generator'].items()
            if k in gsd and state_dict_g['generator'][k].shape == gsd[k].shape
        })
        missing_keys = {
            k: v
            for k, v in state_dict_g['generator'].items()
            if not (k in gsd
                    and state_dict_g['generator'][k].shape == gsd[k].shape)
        }.keys()
        generator.load_state_dict(gsd)
        del gsd, state_dict_g

    if cp_do is None or len(missing_keys) or a.from_zero:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_do = load_checkpoint(cp_do, device)
        mpd.load_state_dict(state_dict_do['mpd'])
        del state_dict_do['mpd']
        msd.load_state_dict(state_dict_do['msd'])
        del state_dict_do['msd']
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(),
                                                mpd.parameters()),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])
        del state_dict_do

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d,
                                                         gamma=h.lr_decay,
                                                         last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(
        a, h.segment_size, h.sampling_rate)

    trainset = MelDataset(training_filelist,
                          h.segment_size,
                          h.n_fft,
                          h.num_mels,
                          h.hop_size,
                          h.win_size,
                          h.sampling_rate,
                          h.fmin,
                          h.fmax,
                          n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True,
                          fmax_loss=h.fmax_for_loss,
                          device=device,
                          fine_tuning=a.fine_tuning,
                          trim_non_voiced=a.trim_non_voiced)

    STFT = STFT_Class(h.sampling_rate, h.num_mels, h.n_fft, h.win_size,
                      h.hop_size, h.fmin, h.fmax)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=h.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)
    assert len(train_loader), 'No audio files in dataset!'

    if rank == 0:
        validset = MelDataset(validation_filelist,
                              h.segment_size,
                              h.n_fft,
                              h.num_mels,
                              h.hop_size,
                              h.win_size,
                              h.sampling_rate,
                              h.fmin,
                              h.fmax,
                              False,
                              False,
                              n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss,
                              device=device,
                              fine_tuning=a.fine_tuning,
                              trim_non_voiced=a.trim_non_voiced)
        validation_loader = DataLoader(validset,
                                       num_workers=h.num_workers,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'),
                           max_queue=10000)
        sw.logged_gt_plots = False

    if h.num_gpus > 1:
        import gc
        gc.collect()
        torch.cuda.empty_cache()

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device,
                                                     non_blocking=True))
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = STFT.get_mel(y_g_hat.squeeze(1))

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel)

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel * 45

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                torch.set_grad_enabled(False)
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, loss_gen_all, loss_mel.item(),
                                time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator': (generator.module if h.num_gpus > 1
                                          else generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'mpd': (mpd.module
                                    if h.num_gpus > 1 else mpd).state_dict(),
                            'msd': (msd.module
                                    if h.num_gpus > 1 else msd).state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch,
                        })
                    del_old_checkpoints(a.checkpoint_path, 'g_',
                                        a.n_models_to_keep)
                    del_old_checkpoints(a.checkpoint_path, 'do_',
                                        a.n_models_to_keep)

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", loss_mel.item(),
                                  steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    print("Validating...")
                    n_audios_to_plot = 6
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    for j, batch in tqdm(enumerate(validation_loader),
                                         total=len(validation_loader)):
                        x, y, _, y_mel = batch
                        y_g_hat = generator(x.to(device))
                        y_hat_spec = STFT.get_mel(y_g_hat.squeeze(1))
                        val_err_tot += F.l1_loss(y_mel,
                                                 y_hat_spec.to(y_mel)).item()

                        if j < n_audios_to_plot and not sw.logged_gt_plots:
                            sw.add_audio(f'gt/y_{j}', y[0], steps,
                                         h.sampling_rate)
                            sw.add_figure(f'spec_{j:02}/gt_spec',
                                          plot_spectrogram(y_mel[0]), steps)
                        if j < n_audios_to_plot:
                            sw.add_audio(f'generated/y_hat_{j}', y_g_hat[0],
                                         steps, h.sampling_rate)
                            sw.add_figure(
                                f'spec_{j:02}/pred_spec',
                                plot_spectrogram(
                                    y_hat_spec.squeeze(0).cpu().numpy()),
                                steps)

                        if j > 64:  # I am NOT patient enough to complete an entire validation cycle with over 1536 files.
                            break
                    sw.logged_gt_plots = True
                    val_err = val_err_tot / (j + 1)
                    sw.add_scalar("validation/mel_spec_error", val_err, steps)
                    generator.train()
                    print(f"Done. Val_loss = {val_err}")
                torch.set_grad_enabled(True)
            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
Example #16
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print("number of params:", n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(
        dataset_train,
        batch_sampler=batch_sampler_train,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size if args.batch_size < 4 else 4,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    elif args.dataset_file in ["cmdd", "cmdc", "wider"]:
        base_ds = None
    elif args.dataset_file == "MOT17":
        base_ds = dataset_val
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model_without_ddp.detr.load_state_dict(checkpoint["model"])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location="cpu",
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")

        # NOTE: this is Bruno's hack to load stuff in
        model_dict = model_without_ddp.state_dict()
        pretrained_dict = checkpoint["model"]
        # hack for adding query stuff
        if ("query_embed.query_embed.weight" in model_dict.keys()
                and "query_embed.weight" in pretrained_dict.keys()):
            pretrained_dict[
                "query_embed.query_embed.weight"] = pretrained_dict[
                    "query_embed.weight"]
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # if finetuning skip the linear stuff
        if args.finetune:
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k not in ["class_embed.weight", "class_embed.bias"]
            }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load new state dict
        model_without_ddp.load_state_dict(model_dict)

        if (not args.eval and not args.load_model_only
                and "optimizer" in checkpoint and "lr_scheduler" in checkpoint
                and "epoch" in checkpoint):
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1

    if args.eval:
        if args.test and args.dataset_file == "wider":
            if args.resume:
                s = args.resume.split("/")[:-1]
                output_dir = "/" + os.path.join(*s)
            else:
                output_dir = args.output_dir
            print("SAVING TEST WIDER TO ", output_dir)
            test_wider(
                model,
                criterion,
                postprocessors,
                dataset_val,
                data_loader_val,
                device,
                output_dir,
            )
            return
        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
            dset_file=args.dataset_file,
        )
        if args.output_dir and coco_evaluator is not None:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval,
                                 output_dir / "eval.pth")
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
        )
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / "checkpoint.pth"]
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir /
                                        f"checkpoint{epoch:04}.pth")
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        "model": model_without_ddp.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "lr_scheduler": lr_scheduler.state_dict(),
                        "epoch": epoch,
                        "args": args,
                    },
                    checkpoint_path,
                )

        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
            dset_file=args.dataset_file,
        )

        log_stats = {
            **{f"train_{k}": v
               for k, v in train_stats.items()},
            **{f"test_{k}": v
               for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / "eval").mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        torch.save(
                            coco_evaluator.coco_eval["bbox"].eval,
                            output_dir / "eval" / name,
                        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
Example #17
def train(model, optimizer, scheduler, global_step, train_dataset, dev_dataset,
          opt, collator, best_eval_loss):

    if opt.is_main:
        try:
            tb_logger = torch.utils.tensorboard.SummaryWriter(
                Path(opt.checkpoint_dir) / opt.name)
        except Exception:
            tb_logger = None
            logger.warning('Tensorboard is not available.')
    train_sampler = DistributedSampler(
        train_dataset) if opt.is_distributed else RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=opt.per_gpu_batch_size,
                                  drop_last=True,
                                  num_workers=10,
                                  collate_fn=collator)

    loss, curr_loss = 0.0, 0.0
    epoch = 1
    model.train()
    while global_step < opt.total_steps:
        if opt.is_distributed:
            # Reseed the DistributedSampler so each epoch gets a new shuffle.
            train_sampler.set_epoch(epoch)
        epoch += 1
        for i, batch in enumerate(train_dataloader):
            global_step += 1
            (idx, question_ids, question_mask, passage_ids, passage_mask,
             gold_score) = batch
            _, _, _, train_loss = model(
                question_ids=question_ids.cuda(),
                question_mask=question_mask.cuda(),
                passage_ids=passage_ids.cuda(),
                passage_mask=passage_mask.cuda(),
                gold_score=gold_score.cuda(),
            )

            train_loss.backward()

            if global_step % opt.accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
                optimizer.step()
                scheduler.step()
                model.zero_grad()

            train_loss = src.util.average_main(train_loss, opt)
            curr_loss += train_loss.item()

            if global_step % opt.eval_freq == 0:
                eval_loss, inversions, avg_topk, idx_topk = evaluate(
                    model, dev_dataset, collator, opt)
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    if opt.is_main:
                        src.util.save(model, optimizer, scheduler, global_step,
                                      best_eval_loss, opt, dir_path,
                                      'best_dev')
                model.train()
                if opt.is_main:
                    log = f"{global_step} / {opt.total_steps}"
                    log += f" -- train: {curr_loss/opt.eval_freq:.6f}"
                    log += f", eval: {eval_loss:.6f}"
                    log += f", inv: {inversions:.1f}"
                    log += f", lr: {scheduler.get_last_lr()[0]:.6f}"
                    for k in avg_topk:
                        log += f" | avg top{k}: {100*avg_topk[k]:.1f}"
                    for k in idx_topk:
                        log += f" | idx top{k}: {idx_topk[k]:.1f}"
                    logger.info(log)

                    if tb_logger is not None:
                        tb_logger.add_scalar("Evaluation", eval_loss,
                                             global_step)
                        tb_logger.add_scalar("Training",
                                             curr_loss / (opt.eval_freq),
                                             global_step)
                    curr_loss = 0

            if opt.is_main and global_step % opt.save_freq == 0:
                src.util.save(model, optimizer, scheduler, global_step,
                              best_eval_loss, opt, dir_path,
                              f"step-{global_step}")
            if global_step > opt.total_steps:
                break
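
A stripped-down sketch of the sampler handling used in the loop above: pick DistributedSampler only when a process group is initialized, fall back to RandomSampler otherwise, and reseed the shuffle every epoch. Dataset, batch size, and epoch count below are placeholders so the sketch runs standalone:

import torch
import torch.distributed as dist
from torch.utils.data import (DataLoader, DistributedSampler, RandomSampler,
                              TensorDataset)

dataset = TensorDataset(torch.arange(100).float())
distributed = dist.is_available() and dist.is_initialized()

# DistributedSampler shards the dataset across ranks; RandomSampler is the
# single-process fallback.
sampler = DistributedSampler(dataset) if distributed else RandomSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=8, drop_last=True)

for epoch in range(3):
    if distributed:
        # Without set_epoch every epoch on a given rank replays the same order.
        sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward / backward / optimizer step
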
Example #18
class CTLTrainer(Trainer):
    def __init__(
        self,
        model: nn.Module,
        train_dataset: TSBaseDataset,
        valid_dataset: TSBaseDataset,
        test_dataset: TSBaseDataset,
        optimizer,
        evaluator: MetricEvaluator,
        criterion,
        config,
    ):
        self.config = config

        self._stop_training = False

        self.metrics = {}

        callbacks = [
            hydra.utils.call(callback_config)
            for callback_config in self.config.trainer.callback.values()
        ]
        self.callbacks = CTLCallbackContainer(self, callbacks)

        self.world_size = self.config.device.get("world_size", 1)
        train_dataset = sample_data(
            train_dataset, self.config.dataset.get("train_samples", -1))
        valid_dataset = sample_data(
            valid_dataset, self.config.dataset.get("valid_samples", -1))
        self.valid_dataset_len = len(valid_dataset)
        self.train_dataset_len = len(train_dataset)
        self.train_sampler = None
        self.valid_sampler = None
        if self.world_size > 1:
            local_rank = int(
                self.config.device.get("local_rank",
                                       os.environ.get("LOCAL_RANK", 0)))
            self.device = get_device(local_rank,
                                     self.config.device.get("name", "cpu"))
            self.is_distributed = init_distributed(
                int(
                    self.config.device.get("world_size",
                                           os.environ.get("WORLD_SIZE", 1))))
            torch.cuda.synchronize()
            self.train_sampler = DistributedSampler(train_dataset,
                                                    config.device.world_size,
                                                    seed=config.trainer.get(
                                                        "seed", 0),
                                                    drop_last=True)
            self.valid_sampler = DistributedSampler(valid_dataset,
                                                    config.device.world_size,
                                                    seed=config.trainer.get(
                                                        "seed", 0),
                                                    drop_last=False)
        elif self.config.device.get("local_rank", None):
            self.device = get_device(self.config.device.get("local_rank"),
                                     self.config.device.get("name", "cpu"))
        else:
            self.device = torch.device(self.config.device.get("name", "cpu"))
        self.logger = setup_logger(self.config)
        self.optimizer = optimizer
        self.amp_enabled = self.config.trainer.get("AMP", False)
        self.model = model.to(self.device)

        if config.trainer.get("ema", None) is not None:
            self.ema = ModelEmaV2(config, model, self.device)
        else:
            self.ema = None
        if self.amp_enabled:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O2",
                                                        loss_scale="dynamic")
        if self.world_size > 1:
            self.model = DDP(self.model,
                             device_ids=[local_rank],
                             output_device=local_rank,
                             find_unused_parameters=True)

        # TODO: line below has to go somewhere else. Or use default print. Logging module alters std streams which prevents us from
        # capturing their outputs.
        # log(config.pretty())

        # XXX: Not sure about this. Maybe this should be isolated in collate_fn inside a DataLoader. Or maybe we should abstract it away in data_utils?
        # For sure we have to rename this. This suggests that masked target is somehow different from
        # regular target.
        self.train_target = "target_masked" if config.model.get(
            "train_target_mask", True) else "target"
        self.eval_target = "target_masked" if config.model.get(
            "eval_target_mask", True) else "target"
        self.test_target = "target_masked" if config.model.get(
            "test_target_mask", True) else "target"

        if self.config.dataset.get("graph", False) and self.config.model.get(
                "graph_eligible", False):

            def _collate_graph(samples, target):
                batch = dgl.batch(samples)
                labels = batch.ndata["target"]
                # XXX: we need discuss how to do this neatly
                if target == "target_masked":
                    labels = labels[:, self.config.dataset.encoder_length:, :]

                return batch, labels

            _collate = _collate_graph
        else:

            def _collate_dict(samples, target):
                batch = default_collate(samples)
                labels = batch["target"]
                if target == "target_masked":
                    labels = labels[:, self.config.dataset.encoder_length:, :]
                return batch, labels

            _collate = _collate_dict

        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.trainer.batch_size,
            num_workers=self.config.trainer.num_workers,
            sampler=self.train_sampler,
            shuffle=True if self.train_sampler is None else False,
            pin_memory=True,
            collate_fn=partial(_collate, target=self.train_target),
        )
        self.valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=self.config.trainer.batch_size,
            num_workers=self.config.trainer.num_workers,
            sampler=self.valid_sampler,
            shuffle=True if self.valid_sampler is None else False,
            pin_memory=True,
            collate_fn=partial(_collate, target=self.eval_target),
        )
        self.test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.config.trainer.batch_size,
            num_workers=1,
            pin_memory=True,
            collate_fn=partial(_collate, target=self.test_target),
        )
        if self.config.get("scheduler", None):
            self.scheduler = hydra.utils.instantiate(self.config.scheduler,
                                                     optimizer)
        else:
            self.scheduler = None

        self.evaluator = evaluator
        self.criterion = criterion

        self.log_path = self.config.get("log_path", os.getcwd())
        self.global_step = 0
        self.epoch = 0

        self.preds_train_output_selector = config.model.get(
            "preds_train_output_selector", -1)
        self.preds_eval_output_selector = config.model.get(
            "preds_eval_output_selector", -1)
        self.preds_test_output_selector = config.model.get(
            "preds_test_output_selector", -1)

        model_ref = self.model.module if self.world_size > 1 else self.model
        test_method_name = config.model.get("test_method", "__call__")
        self.test_method = getattr(model_ref, test_method_name)

        checkpoint_path = config.trainer.get("checkpoint_path", None)
        maybe_restore_checkpoint(self, checkpoint_path)

    def assess_valid(self):
        self.model.eval()
        with torch.no_grad():
            running_losses = 0

            for i, (batch, labels) in enumerate(self.valid_dataloader):
                batch = to_device(batch, device=self.device)
                labels = to_device(labels, device=self.device)
                if self.ema:
                    preds = self.ema.module(batch)
                else:
                    preds = self.model(batch)
                if self.preds_eval_output_selector >= 0:
                    preds = preds[..., self.preds_eval_output_selector:self.
                                  preds_eval_output_selector + 1]

                losses = self.criterion(preds, labels)
                losses = reduce_tensor(losses, self.world_size).detach()
                running_losses += losses

        running_losses = running_losses / (len(self.valid_dataloader.dataset) /
                                           self.config.trainer.batch_size)
        if len(running_losses.size()) < 1:
            running_losses = running_losses.unsqueeze(0)
        running_losses = [loss.item() for loss in running_losses]
        data = {"val_loss": sum(running_losses)}
        for i, elem in enumerate(running_losses):
            data["val_loss_component_" + str(i)] = elem
        self.logger.log(step=self.global_step,
                        data=data,
                        verbosity=dllogger.Verbosity.VERBOSE)

        self.model.train()
        return sum(running_losses)

    def train(self):

        self.callbacks.on_train_begin()
        self.global_step = 0
        for epoch in range(self.epoch, self.config.trainer.num_epochs):
            self.callbacks.on_epoch_begin(epoch)

            # Reshuffle the distributed samplers at the start of every epoch;
            # calling set_epoch only at the end of an epoch would make the
            # first two epochs reuse the same ordering.
            if self.train_sampler:
                self.train_sampler.set_epoch(epoch)
                self.valid_sampler.set_epoch(epoch)

            self.logger.log(step=self.global_step,
                            data={"epoch": epoch},
                            verbosity=dllogger.Verbosity.VERBOSE)

            for i, (batch, labels) in enumerate(self.train_dataloader):
                self.callbacks.on_batch_begin(i)

                self.optimizer.zero_grad()
                batch = to_device(batch, device=self.device)
                labels = to_device(labels, device=self.device)

                preds = self.model(batch)
                if self.preds_train_output_selector >= 0:
                    preds = preds[..., self.preds_train_output_selector:self.
                                  preds_train_output_selector + 1]

                losses = self.criterion(preds, labels)
                loss = losses.sum()

                if self.amp_enabled:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                # Clip gradients before the optimizer step; clipping after
                # step() has no effect on the update that was just applied.
                if self.config.optimizer.get("gradient_norm", 0.0) > 0:
                    nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.optimizer.gradient_norm)
                self.optimizer.step()

                losses = reduce_tensor(losses, self.world_size, average=True)
                if len(losses.size()) < 1:
                    losses = [losses]
                losses = [loss.item() for loss in losses]
                data = {"loss": loss.item()}
                for k, v in enumerate(losses):
                    data["loss_component_" + str(k)] = v

                self.logger.log(step=self.global_step,
                                data=data,
                                verbosity=dllogger.Verbosity.VERBOSE)

                # XXX: shouldn't we move logging to a callback?
                if self.global_step % self.config.trainer.log_interval == 0:
                    self.logger.flush()
                self.global_step += 1
                self.callbacks.on_batch_end(i, logs=data)
                if self.ema:
                    self.ema.update(self.model)
            if self.scheduler:
                self.scheduler.step()
            self.callbacks.on_valid_begin(epoch)
            validation_loss = self.assess_valid()
            data = {"val_loss": validation_loss}
            self.callbacks.on_valid_end(epoch, logs=data)

            if is_main_process():
                save_checkpoint(self, checkpoint_dir=self.log_path)

            self.callbacks.on_epoch_end(epoch, logs=data)

            if self._stop_training:
                break

        self.callbacks.on_train_end(logs=self.metrics)

    def evaluate(self):
        self.callbacks.on_evaluate_begin()
        maybe_restore_checkpoint(
            self, os.path.join(self.log_path, "best_checkpoint.pth.tar"))
        self.model.eval()

        with torch.no_grad():

            preds_full = []
            labels_full = []
            weights_full = []
            ids_full = []

            for i, (batch, labels) in enumerate(self.test_dataloader):
                batch = to_device(batch, device=self.device)
                labels = to_device(labels, device=self.device)

                if self.config.evaluator.get("use_weights", False):
                    weights = batch["weight"]
                else:
                    weights = None

                # XXX we should abstract this away
                ids = batch.ndata["id"] if isinstance(
                    batch, dgl.DGLGraph) else batch["id"]
                # Assumes the time dimension is at index 1; we don't check
                # whether the example is constructed correctly.
                ids = ids[:, 0, ...]

                labels_full.append(labels)
                weights_full.append(weights)
                preds = self.test_method(batch)
                if self.preds_test_output_selector >= 0:
                    preds = preds[..., self.preds_test_output_selector:self.
                                  preds_test_output_selector + 1]
                ids_full.append(ids)
                preds_full.append(preds)

            preds_full = torch.cat(preds_full, dim=0).cpu().numpy()
            labels_full = torch.cat(labels_full, dim=0).cpu().numpy()

            if self.config.evaluator.get("use_weights", False):
                weights_full = torch.cat(weights_full).cpu().numpy()
            else:
                weights_full = np.zeros((0, 0))
            ids_full = torch.cat(ids_full).cpu().numpy()
            eval_metrics = self.evaluator(labels_full, preds_full,
                                          weights_full, ids_full)

            self.metrics.update(eval_metrics)

            self.logger.log(
                step=[],
                data={k: float(v)
                      for k, v in self.metrics.items()},
                verbosity=dllogger.Verbosity.VERBOSE)
            self.callbacks.on_evaluate_end(
                logs=round_dict(self.metrics, decimal=3))
            return round_dict(self.metrics, decimal=3)
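
The validation loop above averages losses across ranks through a reduce_tensor helper before logging. A minimal version of such a helper, assuming the default process group is initialized whenever world_size > 1; the name mirrors the usage above but the body is reconstructed, not taken from the original codebase:

import torch
import torch.distributed as dist


def reduce_tensor(tensor: torch.Tensor, world_size: int,
                  average: bool = True) -> torch.Tensor:
    # Sum the value over all ranks; with average=True the result is the mean
    # across GPUs, which is what every rank then logs.
    if world_size <= 1 or not dist.is_available() or not dist.is_initialized():
        return tensor
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    if average:
        reduced /= world_size
    return reduced
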
Example #19
    def run(self,
            dataset: torch.utils.data.Dataset,
            memory_set: torch.utils.data.Dataset = None,
            query_set: torch.utils.data.Dataset = None,
            save_every: int = 100,
            **kwargs):

        if not self.prepared:
            raise RuntimeError("Training not prepared.")

        # DataLoader (for self-supervised pre-training)
        sampler = DistributedSampler(dataset) if self.distributed else None
        shuffle = not self.distributed
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True
        )

        # DataLoader (for supervised evaluation)
        if (memory_set is not None) and (query_set is not None):
            memory_loader = DataLoader(memory_set, batch_size=self.batch_size*2, num_workers=self.num_workers)
            query_loader = DataLoader(query_set, batch_size=self.batch_size*2)
            knn_eval = True
        else:
            query_loader = None
            memory_loader = None
            knn_eval = False

        # Logging
        logger = kwargs.get('logger', None)

        for epoch in range(1, self.epochs + 1):

            if self.distributed and (sampler is not None):
                sampler.set_epoch(epoch)

            # Train
            history = self.train(data_loader)
            log = " | ".join([f"{k} : {v:.4f}" for k, v in history.items()])

            # Evaluate
            if (self.local_rank == 0) and knn_eval:
                knn_k = kwargs.get('knn_k', [5, 200])
                knn = KNNEvaluator(knn_k, num_classes=query_loader.dataset.num_classes)
                knn_scores = knn.evaluate(self.net_q,
                                          memory_loader=memory_loader,
                                          query_loader=query_loader)
                for k, score in knn_scores.items():
                    log += f" | knn@{k}: {score*100:.2f}%"
            else:
                knn_scores = None

            # Logging
            if logger is not None:
                logger.info(f"Epoch [{epoch:>4}/{self.epochs:>4}] - " + log)

            # TensorBoard
            if self.writer is not None:
                for k, v in history.items():
                    self.writer.add_scalar(k, v, global_step=epoch)
                if knn_scores is not None:
                    for k, score in knn_scores.items():
                        self.writer.add_scalar(f'knn@{k}', score, global_step=epoch)
                if self.scheduler is not None:
                    lr = self.scheduler.get_last_lr()[0]
                    self.writer.add_scalar('lr', lr, global_step=epoch)

            if (epoch % save_every == 0) and (self.local_rank == 0):
                ckpt = os.path.join(self.ckpt_dir, f"ckpt.{epoch}.pth.tar")
                self.save_checkpoint(ckpt, epoch=epoch, history=history)

            if self.scheduler is not None:
                self.scheduler.step()
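
As in the loaders above, a DataLoader cannot take both a sampler and shuffle=True, which is why shuffle is derived from whether a distributed sampler is in use. A small helper capturing that rule; the function and its defaults are illustrative, not part of the snippet above:

import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def build_loader(dataset: Dataset, batch_size: int,
                 num_workers: int = 0) -> DataLoader:
    # Under DDP the DistributedSampler owns the shuffling, so shuffle must be
    # False; in single-process runs the DataLoader shuffles instead.
    use_ddp = dist.is_available() and dist.is_initialized()
    sampler = DistributedSampler(dataset) if use_ddp else None
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
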
Example #20
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(split='train', args=args)
    dataset_val = build_dataset(split='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    elif args.dataset_file == "coco":
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    # if args.eval:
    #     if 'coco' in args.dataset_file:
    #         test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
    #                                             data_loader_val, base_ds, device, args.output_dir)
    #         if args.output_dir:
    #             utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
    #     elif 'anet' == args.dataset_file:
    #         evaluate3d(model, postprocessors, data_loader_val, device, epoch=0)
    #     return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        if epoch % args.eval_freq == 0:
            if 'coco' in args.dataset_file:
                test_stats, coco_evaluator = evaluate(model, criterion,
                                                      postprocessors,
                                                      data_loader_val, base_ds,
                                                      device, args.output_dir)
            elif 'anet' == args.dataset_file:
                evaluate3d(model, postprocessors, data_loader_val, device,
                           epoch)
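
The training loader above wraps the DistributedSampler in a BatchSampler instead of passing batch_size to the DataLoader. A self-contained sketch of that combination with a toy dataset; rank and num_replicas are fixed explicitly only so it runs in a single process:

import torch
from torch.utils.data import (BatchSampler, DataLoader, DistributedSampler,
                              TensorDataset)

dataset = TensorDataset(torch.arange(32).float())

# In a real run rank/num_replicas come from the initialized process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=True)

loader = DataLoader(dataset, batch_sampler=batch_sampler)

for epoch in range(2):
    # set_epoch still has to be called on the inner sampler each epoch.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # train_one_epoch(...) in the example above
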
Example #21
class ForcesTrainer(BaseTrainer):
    """
    Trainer class for the Structure to Energy & Force (S2EF) and Initial State to
    Relaxed State (IS2RS) tasks.

    .. note::

        Examples of configurations for task, model, dataset and optimizer
        can be found in `configs/ocp_s2ef <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2re/>`_
        and `configs/ocp_is2rs <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2rs/>`_.

    Args:
        task (dict): Task configuration.
        model (dict): Model configuration.
        dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset.
        optimizer (dict): Optimizer configuration.
        identifier (str): Experiment identifier that is appended to log directory.
        run_dir (str, optional): Path to the run directory where logs are to be saved.
            (default: :obj:`None`)
        is_debug (bool, optional): Run in debug mode.
            (default: :obj:`False`)
        is_vis (bool, optional): Run in visualization mode.
            (default: :obj:`False`)
        is_hpo (bool, optional): Run hyperparameter optimization with Ray Tune.
            (default: :obj:`False`)
        print_every (int, optional): Frequency of printing logs.
            (default: :obj:`100`)
        seed (int, optional): Random number seed.
            (default: :obj:`None`)
        logger (str, optional): Type of logger to be used.
            (default: :obj:`tensorboard`)
        local_rank (int, optional): Local rank of the process, only applicable for distributed training.
            (default: :obj:`0`)
        amp (bool, optional): Run using automatic mixed precision.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        is_hpo=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
        cpu=False,
    ):
        super().__init__(
            task=task,
            model=model,
            dataset=dataset,
            optimizer=optimizer,
            identifier=identifier,
            run_dir=run_dir,
            is_debug=is_debug,
            is_vis=is_vis,
            is_hpo=is_hpo,
            print_every=print_every,
            seed=seed,
            logger=logger,
            local_rank=local_rank,
            amp=amp,
            cpu=cpu,
            name="s2ef",
        )

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            1 if not self.cpu else 0,
            self.config["model_attributes"].get("otf_graph", False),
        )
        if self.config["task"]["dataset"] == "trajectory_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["dataset"])

            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=True,
            )

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.train_sampler,
            )

            self.val_loader = self.test_loader = None
            self.val_sampler = self.test_sampler = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"])(self.config["val_dataset"])
                self.val_sampler = DistributedSampler(
                    self.val_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.val_sampler,
                )
            if "test_dataset" in self.config:
                self.test_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"])(
                        self.config["test_dataset"])
                self.test_sampler = DistributedSampler(
                    self.test_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.test_loader = DataLoader(
                    self.test_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.test_sampler,
                )

        if "relax_dataset" in self.config["task"]:
            assert os.path.isfile(self.config["task"]["relax_dataset"]["src"])

            self.relax_dataset = registry.get_dataset_class(
                "single_point_lmdb")(self.config["task"]["relax_dataset"])

            self.relax_sampler = DistributedSampler(
                self.relax_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.relax_loader = DataLoader(
                self.relax_dataset,
                batch_size=self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.relax_sampler,
            )

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", False):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__],
                    device=self.device,
                )

        # If we're computing gradients w.r.t. the input, set the normalizer's
        # mean to 0 -- it is lost when computing dy/dx -- and its std to the
        # forward target's std.
        if self.config["model_attributes"].get("regress_forces", True):
            if self.config["dataset"].get("normalize_labels", False):
                if "grad_target_mean" in self.config["dataset"]:
                    self.normalizers["grad_target"] = Normalizer(
                        mean=self.config["dataset"]["grad_target_mean"],
                        std=self.config["dataset"]["grad_target_std"],
                        device=self.device,
                    )
                else:
                    self.normalizers["grad_target"] = Normalizer(
                        tensor=self.train_loader.dataset.data.y[
                            self.train_loader.dataset.__indices__],
                        device=self.device,
                    )
                    self.normalizers["grad_target"].mean.fill_(0)

        if (self.is_vis and self.config["task"]["dataset"] != "qm9"
                and distutils.is_master()):
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)

    # Takes in a new data source and generates predictions on it.
    @torch.no_grad()
    def predict(self,
                data_loader,
                per_image=True,
                results_file=None,
                disable_tqdm=True):
        if distutils.is_master() and not disable_tqdm:
            print("### Predicting on test.")
        assert isinstance(
            data_loader,
            (
                torch.utils.data.dataloader.DataLoader,
                torch_geometric.data.Batch,
            ),
        )
        rank = distutils.get_rank()

        if isinstance(data_loader, torch_geometric.data.Batch):
            data_loader = [[data_loader]]

        self.model.eval()
        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
            self.normalizers["grad_target"].to(self.device)

        predictions = {"id": [], "energy": [], "forces": [], "chunk_idx": []}

        for i, batch_list in tqdm(
                enumerate(data_loader),
                total=len(data_loader),
                position=rank,
                desc="device {}".format(rank),
                disable=disable_tqdm,
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch_list)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"])
                out["forces"] = self.normalizers["grad_target"].denorm(
                    out["forces"])
            if per_image:
                systemids = [
                    str(i) + "_" + str(j) for i, j in zip(
                        batch_list[0].sid.tolist(), batch_list[0].fid.tolist())
                ]
                predictions["id"].extend(systemids)
                predictions["energy"].extend(out["energy"].to(
                    torch.float16).tolist())
                batch_natoms = torch.cat(
                    [batch.natoms for batch in batch_list])
                batch_fixed = torch.cat([batch.fixed for batch in batch_list])
                forces = out["forces"].cpu().detach().to(torch.float16)
                per_image_forces = torch.split(forces, batch_natoms.tolist())
                per_image_forces = [
                    force.numpy() for force in per_image_forces
                ]
                # evalAI only requires forces on free atoms
                if results_file is not None:
                    _per_image_fixed = torch.split(batch_fixed,
                                                   batch_natoms.tolist())
                    _per_image_free_forces = [
                        force[(fixed == 0).tolist()] for force, fixed in zip(
                            per_image_forces, _per_image_fixed)
                    ]
                    _chunk_idx = np.array([
                        free_force.shape[0]
                        for free_force in _per_image_free_forces
                    ])
                    per_image_forces = _per_image_free_forces
                    predictions["chunk_idx"].extend(_chunk_idx)
                predictions["forces"].extend(per_image_forces)
            else:
                predictions["energy"] = out["energy"].detach()
                predictions["forces"] = out["forces"].detach()
                return predictions

        predictions["forces"] = np.array(predictions["forces"])
        predictions["chunk_idx"] = np.array(predictions["chunk_idx"])
        predictions["energy"] = np.array(predictions["energy"])
        predictions["id"] = np.array(predictions["id"])
        self.save_results(predictions,
                          results_file,
                          keys=["energy", "forces", "chunk_idx"])
        return predictions

    def train(self):
        eval_every = self.config["optim"].get("eval_every",
                                              len(self.train_loader))
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        iters = 0
        self.metrics = {}

        start_epoch = self.start_step // len(self.train_loader)
        for epoch in range(start_epoch, self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            skip_steps = 0
            if epoch == start_epoch and start_epoch > 0:
                # Resuming mid-epoch: skip the steps already covered before
                # the checkpoint was written.
                skip_steps = self.start_step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.model.train()
                current_epoch = epoch + (i + 1) / len(self.train_loader)
                current_step = epoch * len(self.train_loader) + (i + 1)

                # Get a batch.
                batch = next(train_loader_iter)

                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": current_epoch,
                    "step": current_step,
                })
                if (current_step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=current_step,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            epoch=epoch - 1 + (i + 1) / len(self.train_loader),
                        )
                        # For error metrics ("mae") lower is better; for all
                        # other metrics higher is better.
                        current_metric = val_metrics[primary_metric]["metric"]
                        improved = (current_metric < self.best_val_metric
                                    if "mae" in primary_metric else
                                    current_metric > self.best_val_metric)
                        if improved:
                            self.best_val_metric = val_metrics[primary_metric][
                                "metric"]
                            self.save(current_epoch, current_step, val_metrics)
                            if self.test_loader is not None:
                                self.predict(
                                    self.test_loader,
                                    results_file="predictions",
                                    disable_tqdm=False,
                                )

                        if self.is_hpo:
                            self.hpo_update(
                                current_epoch,
                                current_step,
                                self.metrics,
                                val_metrics,
                            )

                    else:
                        self.save(current_epoch, current_step, self.metrics)

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if iters % eval_every == 0:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"])
                else:
                    self.scheduler.step()

            torch.cuda.empty_cache()

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()

    def _forward(self, batch_list):
        # forward pass.
        if self.config["model_attributes"].get("regress_forces", True):
            out_energy, out_forces = self.model(batch_list)
        else:
            out_energy = self.model(batch_list)

        if out_energy.shape[-1] == 1:
            out_energy = out_energy.view(-1)

        out = {
            "energy": out_energy,
        }

        if self.config["model_attributes"].get("regress_forces", True):
            out["forces"] = out_forces

        return out

    def _compute_loss(self, out, batch_list):
        loss = []

        # Energy loss.
        energy_target = torch.cat(
            [batch.y.to(self.device) for batch in batch_list], dim=0)
        if self.config["dataset"].get("normalize_labels", False):
            energy_target = self.normalizers["target"].norm(energy_target)
        energy_mult = self.config["optim"].get("energy_coefficient", 1)
        loss.append(energy_mult * self.criterion(out["energy"], energy_target))

        # Force loss.
        if self.config["model_attributes"].get("regress_forces", True):
            force_target = torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0)
            if self.config["dataset"].get("normalize_labels", False):
                force_target = self.normalizers["grad_target"].norm(
                    force_target)

            tag_specific_weights = self.config["task"].get(
                "tag_specific_weights", [])
            if tag_specific_weights != []:
                # handle tag specific weights as introduced in forcenet
                assert len(tag_specific_weights) == 3

                batch_tags = torch.cat(
                    [
                        batch.tags.float().to(self.device)
                        for batch in batch_list
                    ],
                    dim=0,
                )
                weight = torch.zeros_like(batch_tags)
                weight[batch_tags == 0] = tag_specific_weights[0]
                weight[batch_tags == 1] = tag_specific_weights[1]
                weight[batch_tags == 2] = tag_specific_weights[2]

                loss_force_list = torch.abs(out["forces"] - force_target)
                train_loss_force_unnormalized = torch.sum(loss_force_list *
                                                          weight.view(-1, 1))
                train_loss_force_normalizer = 3.0 * weight.sum()

                # add up normalizer to obtain global normalizer
                distutils.all_reduce(train_loss_force_normalizer)

                # perform loss normalization before backprop
                train_loss_force_normalized = train_loss_force_unnormalized * (
                    distutils.get_world_size() / train_loss_force_normalizer)
                loss.append(train_loss_force_normalized)

            else:
                # Force coefficient = 30 has been working well for us.
                force_mult = self.config["optim"].get("force_coefficient", 30)
                if self.config["task"].get("train_on_free_atoms", False):
                    fixed = torch.cat(
                        [batch.fixed.to(self.device) for batch in batch_list])
                    mask = fixed == 0
                    loss.append(force_mult * self.criterion(
                        out["forces"][mask], force_target[mask]))
                else:
                    loss.append(force_mult *
                                self.criterion(out["forces"], force_target))
        # Sanity check to make sure the compute graph is correct.
        for lc in loss:
            assert hasattr(lc, "grad_fn")

        loss = sum(loss)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics=None):
        # Avoid a mutable default argument shared across calls.
        if metrics is None:
            metrics = {}
        natoms = torch.cat(
            [batch.natoms.to(self.device) for batch in batch_list], dim=0)

        target = {
            "energy":
            torch.cat([batch.y.to(self.device) for batch in batch_list],
                      dim=0),
            "forces":
            torch.cat([batch.force.to(self.device) for batch in batch_list],
                      dim=0),
            "natoms":
            natoms,
        }

        out["natoms"] = natoms

        if self.config["task"].get("eval_on_free_atoms", True):
            fixed = torch.cat(
                [batch.fixed.to(self.device) for batch in batch_list])
            mask = fixed == 0
            out["forces"] = out["forces"][mask]
            target["forces"] = target["forces"][mask]

            s_idx = 0
            natoms_free = []
            for natoms in target["natoms"]:
                natoms_free.append(
                    torch.sum(mask[s_idx:s_idx + natoms]).item())
                s_idx += natoms
            target["natoms"] = torch.LongTensor(natoms_free).to(self.device)
            out["natoms"] = torch.LongTensor(natoms_free).to(self.device)

        if self.config["dataset"].get("normalize_labels", False):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])
            out["forces"] = self.normalizers["grad_target"].denorm(
                out["forces"])

        metrics = evaluator.eval(out, target, prev_metrics=metrics)
        return metrics

    def run_relaxations(self, split="val", epoch=None):
        print("### Running ML-relaxations")
        self.model.eval()

        evaluator, metrics = Evaluator(task="is2rs"), {}

        if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
                self.relax_dataset[0], "y_relaxed"):
            split = "val"
        else:
            split = "test"

        ids = []
        relaxed_positions = []
        chunk_idx = []
        for i, batch in tqdm(enumerate(self.relax_loader),
                             total=len(self.relax_loader)):
            relaxed_batch = ml_relax(
                batch=batch,
                model=self,
                steps=self.config["task"].get("relaxation_steps", 200),
                fmax=self.config["task"].get("relaxation_fmax", 0.0),
                relax_opt=self.config["task"]["relax_opt"],
                device=self.device,
                transform=None,
            )

            if self.config["task"].get("write_pos", False):
                systemids = [str(i) for i in relaxed_batch.sid.tolist()]
                natoms = relaxed_batch.natoms.tolist()
                positions = torch.split(relaxed_batch.pos, natoms)
                batch_relaxed_positions = [pos.tolist() for pos in positions]

                relaxed_positions += batch_relaxed_positions
                chunk_idx += natoms
                ids += systemids

            if split == "val":
                mask = relaxed_batch.fixed == 0
                s_idx = 0
                natoms_free = []
                for natoms in relaxed_batch.natoms:
                    natoms_free.append(
                        torch.sum(mask[s_idx:s_idx + natoms]).item())
                    s_idx += natoms

                target = {
                    "energy": relaxed_batch.y_relaxed,
                    "positions": relaxed_batch.pos_relaxed[mask],
                    "cell": relaxed_batch.cell,
                    "pbc": torch.tensor([True, True, True]),
                    "natoms": torch.LongTensor(natoms_free),
                }

                prediction = {
                    "energy": relaxed_batch.y,
                    "positions": relaxed_batch.pos[mask],
                    "cell": relaxed_batch.cell,
                    "pbc": torch.tensor([True, True, True]),
                    "natoms": torch.LongTensor(natoms_free),
                }

                metrics = evaluator.eval(prediction, target, metrics)

        if self.config["task"].get("write_pos", False):
            rank = distutils.get_rank()
            pos_filename = os.path.join(self.config["cmd"]["results_dir"],
                                        f"relaxed_pos_{rank}.npz")
            np.savez_compressed(
                pos_filename,
                ids=ids,
                pos=np.array(relaxed_positions, dtype=object),
                chunk_idx=chunk_idx,
            )

            distutils.synchronize()
            if distutils.is_master():
                gather_results = defaultdict(list)
                full_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    "relaxed_positions.npz",
                )

                for i in range(distutils.get_world_size()):
                    rank_path = os.path.join(
                        self.config["cmd"]["results_dir"],
                        f"relaxed_pos_{i}.npz",
                    )
                    rank_results = np.load(rank_path, allow_pickle=True)
                    gather_results["ids"].extend(rank_results["ids"])
                    gather_results["pos"].extend(rank_results["pos"])
                    gather_results["chunk_idx"].extend(
                        rank_results["chunk_idx"])
                    os.remove(rank_path)

                # Because of how distributed sampler works, some system ids
                # might be repeated to make no. of samples even across GPUs.
                _, idx = np.unique(gather_results["ids"], return_index=True)
                gather_results["ids"] = np.array(gather_results["ids"])[idx]
                gather_results["pos"] = np.concatenate(
                    np.array(gather_results["pos"])[idx])
                gather_results["chunk_idx"] = np.cumsum(
                    np.array(gather_results["chunk_idx"])[idx]
                )[:-1]  # np.split does not need last idx, assumes n-1:end

                print(f"Writing results to {full_path}")
                np.savez_compressed(full_path, **gather_results)

        if split == "val":
            aggregated_metrics = {}
            for k in metrics:
                aggregated_metrics[k] = {
                    "total":
                    distutils.all_reduce(metrics[k]["total"],
                                         average=False,
                                         device=self.device),
                    "numel":
                    distutils.all_reduce(metrics[k]["numel"],
                                         average=False,
                                         device=self.device),
                }
                aggregated_metrics[k]["metric"] = (
                    aggregated_metrics[k]["total"] /
                    aggregated_metrics[k]["numel"])
            metrics = aggregated_metrics

            # Make plots.
            log_dict = {k: metrics[k]["metric"] for k in metrics}
            if self.logger is not None and epoch is not None:
                self.logger.log(
                    log_dict,
                    step=(epoch + 1) * len(self.train_loader),
                    split=split,
                )

            if distutils.is_master():
                print(metrics)
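
As the comment in run_relaxations notes, DistributedSampler pads the index list with repeats so that every rank receives the same number of samples, which means gathered predictions can contain duplicate ids. A standalone illustration of the np.unique de-duplication used above, with toy ids rather than the relaxation outputs:

import numpy as np

# Ids gathered from two ranks; "0" appears twice because the sampler padded
# the shorter shard to keep per-rank sample counts equal.
gathered_ids = np.array(["3", "1", "0", "2", "4", "0"])
gathered_pred = np.array([3.0, 1.0, 0.0, 2.0, 4.0, 0.0])

# np.unique keeps the first occurrence of each id, dropping the padding.
unique_ids, first_idx = np.unique(gathered_ids, return_index=True)
deduped_pred = gathered_pred[first_idx]

print(unique_ids)    # ['0' '1' '2' '3' '4']
print(deduped_pred)  # [0. 1. 2. 3. 4.]
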
Example #22
def main(args):
    utils.init_distributed_mode(args)

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 batch_size=1,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        io.load_frozen(args, model_without_ddp)

    output_dir = Path(args.output_dir)
    if args.resume:
        io.resume(args, model_without_ddp, optimizer, lr_scheduler)

    elif args.finetune:
        io.finetune(args, model_without_ddp)

    if args.eval:

        if args.output_dir and utils.is_main_process():
            io.init_wandb(args.dataset_file + "-detr-eval",
                          model,
                          args,
                          n_parameters=n_parameters)

        test_stats, evaluator = evaluate(model, criterion, postprocessors,
                                         data_loader_val, base_ds, device,
                                         args.output_dir)
        if args.output_dir:
            io.save_on_master(evaluator.coco_eval["bbox"].eval,
                              output_dir / "eval.pth")
        return

    print("Start training")
    start_time = time.time()

    if args.output_dir and utils.is_main_process():
        io.init_wandb(args.dataset_file + "-detr",
                      model,
                      args,
                      n_parameters=n_parameters)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                io.save_checkpoint(args, model_without_ddp, optimizer,
                                   lr_scheduler, epoch)

        test_stats, evaluator = evaluate(model, criterion, postprocessors,
                                         data_loader_val, base_ds, device,
                                         args.output_dir, epoch)

        if utils.is_main_process() and args.output_dir:
            io.log_wandb(train_stats, test_stats)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))

    # save final model
    if utils.is_main_process() and args.output_dir:
        io.save_on_master(model_without_ddp, output_dir / "model_final.pth")

    print('Training time {}'.format(total_time_str))
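
Both the train loop above and the one below call sampler_train.set_epoch(epoch) at the
start of each epoch; without it, DistributedSampler reuses the same shuffle seed and
every epoch visits the data in the same order on every rank. A minimal standalone sketch
of that behavior (hypothetical toy dataset; passing num_replicas/rank explicitly avoids
needing an initialized process group):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(2):
    sampler.set_epoch(epoch)           # reseeds the shuffle for this epoch
    print(epoch, list(iter(sampler)))  # this rank's indices change between epochs
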
Example #23
0
def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print('Loading data')
    dataset_train = build_dataset(args.train_set, args.dataset_year, args)
    dataset_val = build_dataset(args.val_set, args.dataset_year, args)
    base_ds = get_coco_api_from_dataset(dataset_val)

    print('Creating data loaders')
    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train,
        args.batch_size,
        drop_last=True,
    )

    data_loader_train = DataLoader(
        dataset_train,
        batch_sampler=batch_sampler_train,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )

    print('Creating model, always set args.return_criterion to True')
    args.return_criterion = True
    model = yolov5s(num_classes=args.num_classes)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
        )
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    if args.lr_scheduler == 'cosine':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.t_max)
    elif args.lr_scheduler == 'multi-step':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=args.lr_steps,
            gamma=args.lr_gamma,
        )
    else:
        raise ValueError(f'scheduler {args.lr_scheduler} not supported')

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, data_loader_val, base_ds, device)
        return

    print('Start training')
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader_train, device, epoch,
                        args.print_freq)

        lr_scheduler.step()
        if args.output_dir:
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'args': args,
                    'epoch': epoch,
                },
                output_dir.joinpath(f'model_{epoch}.pth'),
            )

        # evaluate after every epoch
        # evaluate(model, criterion, data_loader_val, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time {total_time_str}')
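
The checkpoint above is written through utils.save_on_master so only rank 0 touches the
filesystem while the other ranks skip the write. A minimal sketch of such a helper,
assuming torch.distributed as the backend (written from scratch here, not the project's
actual implementation):

import torch
import torch.distributed as dist

def save_on_master(obj, path):
    # Plain save when not running distributed; under DDP, only rank 0 writes.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(obj, path)
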
Example #24
0
def main(args):
    # Init distributed mode
    dist.init_distributed_mode(args)

    # Update dataset specific configs
    if args.dataset_config is not None:
        # https://stackoverflow.com/a/16878364
        d = vars(args)
        with open(args.dataset_config, "r") as f:
            cfg = json.load(f)
        d.update(cfg)

    print("git:\n  {}\n".format(utils.get_sha()))

    # Segmentation related
    if args.mask_model != "none":
        args.masks = True
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"

    print(args)

    device = torch.device(args.device)
    output_dir = Path(args.output_dir)

    # fix the seed for reproducibility
    seed = args.seed + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.set_deterministic(True)

    # Build the model
    model, criterion, contrastive_criterion, qa_criterion, weight_dict = build_model(
        args)
    model.to(device)

    assert (
        criterion is not None or qa_criterion is not None
    ), "Error: should train either detection or question answering (or both)"

    # Get a copy of the model for exponential moving averaged version of the model
    model_ema = deepcopy(model) if args.ema else None
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print("number of params:", n_parameters)

    # Set up optimizers
    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and "text_encoder" not in n
                and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "text_encoder" in n and p.requires_grad
            ],
            "lr":
            args.text_encoder_lr,
        },
    ]
    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(param_dicts,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.optimizer in ["adam", "adamw"]:
        optimizer = torch.optim.AdamW(param_dicts,
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Unsupported optimizer {args.optimizer}")

    # Train dataset
    if len(args.combine_datasets) == 0 and not args.eval:
        raise RuntimeError("Please provide at least one training dataset")

    dataset_train, sampler_train, data_loader_train = None, None, None
    if not args.eval:
        dataset_train = ConcatDataset([
            build_dataset(name, image_set="train", args=args)
            for name in args.combine_datasets
        ])

        # To handle very big datasets, we chunk it into smaller parts.
        if args.epoch_chunks > 0:
            print(
                f"Splitting the training set into {args.epoch_chunks} chunks of size "
                f"approximately {len(dataset_train) // args.epoch_chunks}")
            chunks = torch.chunk(torch.arange(len(dataset_train)),
                                 args.epoch_chunks)
            datasets = [
                torch.utils.data.Subset(dataset_train, chunk.tolist())
                for chunk in chunks
            ]
            if args.distributed:
                samplers_train = [DistributedSampler(ds) for ds in datasets]
            else:
                samplers_train = [
                    torch.utils.data.RandomSampler(ds) for ds in datasets
                ]

            batch_samplers_train = [
                torch.utils.data.BatchSampler(sampler_train,
                                              args.batch_size,
                                              drop_last=True)
                for sampler_train in samplers_train
            ]
            assert len(batch_samplers_train) == len(datasets)
            data_loaders_train = [
                DataLoader(
                    ds,
                    batch_sampler=batch_sampler_train,
                    collate_fn=partial(utils.collate_fn, False),
                    num_workers=args.num_workers,
                ) for ds, batch_sampler_train in zip(datasets,
                                                     batch_samplers_train)
            ]
        else:
            if args.distributed:
                sampler_train = DistributedSampler(dataset_train)
            else:
                sampler_train = torch.utils.data.RandomSampler(dataset_train)

            batch_sampler_train = torch.utils.data.BatchSampler(
                sampler_train, args.batch_size, drop_last=True)
            data_loader_train = DataLoader(
                dataset_train,
                batch_sampler=batch_sampler_train,
                collate_fn=partial(utils.collate_fn, False),
                num_workers=args.num_workers,
            )

    # Val dataset
    if len(args.combine_datasets_val) == 0:
        raise RuntimeError("Please provide at leas one validation dataset")

    Val_all = namedtuple(typename="val_data",
                         field_names=[
                             "dataset_name", "dataloader", "base_ds",
                             "evaluator_list"
                         ])

    val_tuples = []
    for dset_name in args.combine_datasets_val:
        dset = build_dataset(dset_name, image_set="val", args=args)
        sampler = (DistributedSampler(dset, shuffle=False) if args.distributed
                   else torch.utils.data.SequentialSampler(dset))
        dataloader = DataLoader(
            dset,
            args.batch_size,
            sampler=sampler,
            drop_last=False,
            collate_fn=partial(utils.collate_fn, False),
            num_workers=args.num_workers,
        )
        base_ds = get_coco_api_from_dataset(dset)
        val_tuples.append(
            Val_all(dataset_name=dset_name,
                    dataloader=dataloader,
                    base_ds=base_ds,
                    evaluator_list=None))

    if args.frozen_weights is not None:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location="cpu",
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        if "model_ema" in checkpoint and checkpoint["model_ema"] is not None:
            model_without_ddp.detr.load_state_dict(checkpoint["model_ema"],
                                                   strict=False)
        else:
            model_without_ddp.detr.load_state_dict(checkpoint["model"],
                                                   strict=False)

        if args.ema:
            model_ema = deepcopy(model_without_ddp)

    # Used for loading weights from another model and starting training from scratch.
    # Especially useful when loading into a model with different functionality.
    if args.load:
        print("loading from", args.load)
        checkpoint = torch.load(args.load, map_location="cpu")
        if "model_ema" in checkpoint:
            model_without_ddp.load_state_dict(checkpoint["model_ema"],
                                              strict=False)
        else:
            model_without_ddp.load_state_dict(checkpoint["model"],
                                              strict=False)

        if args.ema:
            model_ema = deepcopy(model_without_ddp)

    # Used for resuming training from the checkpoint of a model. Used when training times-out or is pre-empted.
    if args.resume:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location="cpu",
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.eval and "optimizer" in checkpoint and "epoch" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
            args.start_epoch = checkpoint["epoch"] + 1
        if args.ema:
            if "model_ema" not in checkpoint:
                print(
                    "WARNING: ema model not found in checkpoint, resetting to current model"
                )
                model_ema = deepcopy(model_without_ddp)
            else:
                model_ema.load_state_dict(checkpoint["model_ema"])

    def build_evaluator_list(base_ds, dataset_name):
        """Helper function to build the list of evaluators for a given dataset"""
        evaluator_list = []
        if args.no_detection:
            return evaluator_list
        iou_types = ["bbox"]
        if args.masks:
            iou_types.append("segm")

        evaluator_list.append(
            CocoEvaluator(base_ds, tuple(iou_types), useCats=False))
        if "refexp" in dataset_name:
            evaluator_list.append(RefExpEvaluator(base_ds, ("bbox",)))
        if "clevrref" in dataset_name:
            evaluator_list.append(ClevrRefEvaluator(base_ds, ("bbox",)))
        if "flickr" in dataset_name:
            evaluator_list.append(
                FlickrEvaluator(
                    args.flickr_dataset_path,
                    subset="test" if args.test else "val",
                    merge_boxes=args.GT_type == "merged",
                ))
        if "phrasecut" in dataset_name:
            evaluator_list.append(
                PhrasecutEvaluator(
                    "test" if args.test else "miniv",
                    ann_folder=args.phrasecut_orig_ann_path,
                    output_dir=os.path.join(output_dir, "phrasecut_eval"),
                    eval_mask=args.masks,
                ))
        return evaluator_list

    # Runs only evaluation, by default on the validation set unless --test is passed.
    if args.eval:
        test_stats = {}
        test_model = model_ema if model_ema is not None else model
        for i, item in enumerate(val_tuples):
            evaluator_list = build_evaluator_list(item.base_ds,
                                                  item.dataset_name)
            postprocessors = build_postprocessors(args, item.dataset_name)
            item = item._replace(evaluator_list=evaluator_list)
            print(f"Evaluating {item.dataset_name}")
            curr_test_stats = evaluate(
                model=test_model,
                criterion=criterion,
                contrastive_criterion=contrastive_criterion,
                qa_criterion=qa_criterion,
                postprocessors=postprocessors,
                weight_dict=weight_dict,
                data_loader=item.dataloader,
                evaluator_list=item.evaluator_list,
                device=device,
                args=args,
            )
            test_stats.update({
                item.dataset_name + "_" + k: v
                for k, v in curr_test_stats.items()
            })

        log_stats = {
            **{f"test_{k}": v
               for k, v in test_stats.items()},
            "n_parameters": n_parameters,
        }
        print(log_stats)
        return

    # Runs training and evaluates after every --eval_skip epochs
    print("Start training")
    start_time = time.time()
    best_metric = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.epoch_chunks > 0:
            sampler_train = samplers_train[epoch % len(samplers_train)]
            data_loader_train = data_loaders_train[epoch %
                                                   len(data_loaders_train)]
            print(
                f"Starting epoch {epoch // len(data_loaders_train)}, sub_epoch {epoch % len(data_loaders_train)}"
            )
        else:
            print(f"Starting epoch {epoch}")
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model=model,
            criterion=criterion,
            contrastive_criterion=contrastive_criterion,
            qa_criterion=qa_criterion,
            data_loader=data_loader_train,
            weight_dict=weight_dict,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            args=args,
            max_norm=args.clip_max_norm,
            model_ema=model_ema,
        )
        if args.output_dir:
            checkpoint_paths = [output_dir / "checkpoint.pth"]
            # extra checkpoint before LR drop and every 2 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 2 == 0:
                checkpoint_paths.append(output_dir /
                                        f"checkpoint{epoch:04}.pth")
            for checkpoint_path in checkpoint_paths:
                dist.save_on_master(
                    {
                        "model": model_without_ddp.state_dict(),
                        "model_ema":
                        model_ema.state_dict() if args.ema else None,
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "args": args,
                    },
                    checkpoint_path,
                )

        if epoch % args.eval_skip == 0:
            test_stats = {}
            test_model = model_ema if model_ema is not None else model
            for i, item in enumerate(val_tuples):
                evaluator_list = build_evaluator_list(item.base_ds,
                                                      item.dataset_name)
                item = item._replace(evaluator_list=evaluator_list)
                postprocessors = build_postprocessors(args, item.dataset_name)
                print(f"Evaluating {item.dataset_name}")
                curr_test_stats = evaluate(
                    model=test_model,
                    criterion=criterion,
                    contrastive_criterion=contrastive_criterion,
                    qa_criterion=qa_criterion,
                    postprocessors=postprocessors,
                    weight_dict=weight_dict,
                    data_loader=item.dataloader,
                    evaluator_list=item.evaluator_list,
                    device=device,
                    args=args,
                )
                test_stats.update({
                    item.dataset_name + "_" + k: v
                    for k, v in curr_test_stats.items()
                })
        else:
            test_stats = {}

        log_stats = {
            **{f"train_{k}": v
               for k, v in train_stats.items()},
            **{f"test_{k}": v
               for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and dist.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

        if epoch % args.eval_skip == 0:
            if args.do_qa:
                metric = test_stats["gqa_accuracy_answer_total_unscaled"]
            else:
                metric = np.mean([
                    v[1] for k, v in test_stats.items()
                    if "coco_eval_bbox" in k
                ])

            if args.output_dir and metric > best_metric:
                best_metric = metric
                checkpoint_paths = [output_dir / "BEST_checkpoint.pth"]
                # save a dedicated checkpoint whenever the validation metric improves
                for checkpoint_path in checkpoint_paths:
                    dist.save_on_master(
                        {
                            "model": model_without_ddp.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "epoch": epoch,
                            "args": args,
                        },
                        checkpoint_path,
                    )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
Example #25
0
class Brain:
    r"""Brain class abstracts away the details of data loops.

    The primary purpose of the `Brain` class is the implementation of
    the ``fit()`` method, which iterates epochs and datasets for the
    purpose of "fitting" a set of modules to a set of data.

    In order to use the ``fit()`` method, one should sub-class the ``Brain``
    class and override any methods for which the default behavior does not
    match the use case. For a simple use case (e.g., training a single model
    with a single dataset) the only methods that need to be overridden are:

    * ``compute_forward()``
    * ``compute_objectives()``

    The example below illustrates how overriding these two methods is done.

    For more complicated use cases, such as multiple modules that need to
    be updated, the following methods can be overridden:

    * ``fit_batch()``
    * ``evaluate_batch()``

    Arguments
    ---------
    modules : dict of str:torch.nn.Module pairs
        These modules are passed to the optimizer by default if they have
        trainable parameters, and will have ``train()``/``eval()`` called on them.
    opt_class : torch.optim class
        A torch optimizer constructor that takes only the list of
        parameters (e.g. a lambda or partial function definition). By default,
        this will be passed all modules in ``modules`` at the
        beginning of the ``fit()`` method. This behavior can be changed
        by overriding the ``configure_optimizers()`` method.
    hparams : dict
        Each key:value pair should consist of a string key and a hyperparameter
        that is used within the overridden methods. These will
        be accessible via an ``hparams`` attribute, using "dot" notation:
        e.g., self.hparams.model(x).
    run_opts : dict
        A set of options to change the runtime environment, including

        debug (bool)
            If ``True``, this will only iterate a few batches for all
            datasets, to ensure code runs without crashing.
        debug_batches (int)
            Number of batches to run in debug mode. Default: ``2``.
        debug_epochs (int)
            Number of epochs to run in debug mode. Default: ``2``.
            If a non-positive number is passed, all epochs are run.
        jit_module_keys (list of str)
            List of keys in ``modules`` that should be jit compiled.
        distributed_count (int)
            Number of devices to run on.
        distributed_backend (str)
            One of ``ddp_nccl``, ``ddp_gloo``, ``ddp_mpi``, ``data_parallel``.
        device (str)
            The location for performing computations.
        auto_mix_prec (bool)
            If ``True``, automatic mixed-precision is used.
            Activate it only with cuda.
        max_grad_norm (float)
            Default implementation of ``fit_batch()`` uses
            ``clip_grad_norm_`` with this value. Default: ``5``.
        nonfinite_patience (int)
            Number of times to ignore non-finite losses before stopping.
            Default: ``3``.
        noprogressbar (bool)
            Whether to turn off progressbar when training. Default: ``False``.
        ckpt_interval_minutes (float)
            Amount of time between saving intra-epoch checkpoints,
            in minutes. Default: ``15.0``. If non-positive, these are not saved.
    checkpointer : speechbrain.Checkpointer
        By default, this will be used to load checkpoints, and will have the
        optimizer added to continue training if interrupted.

    Example
    -------
    >>> from torch.optim import SGD
    >>> class SimpleBrain(Brain):
    ...     def compute_forward(self, batch, stage):
    ...         return self.modules.model(batch[0])
    ...     def compute_objectives(self, predictions, batch, stage):
    ...         return torch.nn.functional.l1_loss(predictions, batch[0])
    >>> model = torch.nn.Linear(in_features=10, out_features=10)
    >>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1))
    >>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
    """

    def __init__(  # noqa: C901
        self,
        modules=None,
        opt_class=None,
        hparams=None,
        run_opts=None,
        checkpointer=None,
    ):
        self.opt_class = opt_class
        self.checkpointer = checkpointer

        # Arguments passed via the run opts dictionary
        run_opt_defaults = {
            "debug": False,
            "debug_batches": 2,
            "debug_epochs": 2,
            "device": "cpu",
            "data_parallel_count": -1,
            "data_parallel_backend": False,
            "distributed_launch": False,
            "distributed_backend": "nccl",
            "jit_module_keys": None,
            "auto_mix_prec": False,
            "max_grad_norm": 5.0,
            "nonfinite_patience": 3,
            "noprogressbar": False,
            "ckpt_interval_minutes": 0,
        }
        for arg, default in run_opt_defaults.items():
            if run_opts is not None and arg in run_opts:
                if hparams is not None and arg in hparams:
                    logger.info(
                        "Info: " + arg + " arg overridden by command line input"
                    )
                setattr(self, arg, run_opts[arg])
            else:
                # If any arg from run_opt_defaults exist in hparams and
                # not in command line args "run_opts"
                if hparams is not None and arg in hparams:
                    logger.info(
                        "Info: " + arg + " arg from hparam file is used"
                    )
                    setattr(self, arg, hparams[arg])
                else:
                    setattr(self, arg, default)

        if self.data_parallel_backend and self.distributed_launch:
            sys.exit(
                "To use data_parallel backend, start your script with:\n\t"
                "python experiment.py hyperparams.yaml "
                "--data_parallel_backend=True --data_parallel_count=2"
                "To use DDP backend, start your script with:\n\t"
                "python -m torch.distributed.lunch [args]\n"
                "experiment.py hyperparams.yaml --distributed_launch=True "
                "--distributed_backend=nccl"
            )

        # Switch to the right context
        if "cuda" in self.device:
            torch.cuda.set_device(int(self.device[-1]))

        # Put modules on the right device, accessible with dot notation
        self.modules = torch.nn.ModuleDict(modules).to(self.device)

        # Make hyperparams available with dot notation too
        if hparams is not None:
            self.hparams = SimpleNamespace(**hparams)

        # Checkpointer should point at a temporary directory in debug mode
        if (
            self.debug
            and self.checkpointer is not None
            and hasattr(self.checkpointer, "checkpoints_dir")
        ):
            tempdir = tempfile.TemporaryDirectory()
            logger.info(
                "Since debug mode is active, switching checkpointer "
                f"output to temporary directory: {tempdir.name}"
            )
            self.checkpointer.checkpoints_dir = pathlib.Path(tempdir.name)

            # Keep reference to tempdir as long as checkpointer exists
            self.checkpointer.tempdir = tempdir

        # Sampler should be handled by `make_dataloader`
        # or if you provide a DataLoader directly, you can set
        # this.train_sampler = your_sampler
        # to have your_sampler.set_epoch() called on each epoch.
        self.train_sampler = None

        # Automatic mixed precision init
        if self.auto_mix_prec:
            self.scaler = torch.cuda.amp.GradScaler()

        # List parameter count for the user
        total_params = sum(
            p.numel() for p in self.modules.parameters() if p.requires_grad
        )
        if total_params > 0:
            clsname = self.__class__.__name__
            fmt_num = sb.utils.logger.format_order_of_magnitude(total_params)
            logger.info(f"{fmt_num} trainable parameters in {clsname}")

        if self.distributed_launch:
            self.rank = int(os.environ["RANK"])
            if not torch.distributed.is_initialized():
                if self.rank > 0:
                    sys.exit(
                        " ================ WARNING ==============="
                        "Please add sb.ddp_init_group() into your exp.py"
                        "To use DDP backend, start your script with:\n\t"
                        "python -m torch.distributed.launch [args]\n\t"
                        "experiment.py hyperparams.yaml "
                        "--distributed_launch=True --distributed_backend=nccl"
                    )
                else:
                    logger.warn(
                        "To use DDP, please add "
                        "sb.utils.distributed.ddp_init_group() into your exp.py"
                    )
                    logger.info(
                        "Only the main process is alive, "
                        "all other subprocess were killed."
                    )
            # force the models to start and remain synchronized
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # Prepare iterating variables
        self.avg_train_loss = 0.0
        self.step = 0

        # Add this class to the checkpointer for intra-epoch checkpoints
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("brain", self)

    def compute_forward(self, batch, stage):
        """Forward pass, to be overridden by sub-classes.

        Arguments
        ---------
        batch : torch.Tensor or tensors
            An element from the dataloader, including inputs for processing.
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

        Returns
        -------
        torch.Tensor or Tensors
            The outputs after all processing is complete.
            Directly passed to ``compute_objectives()``.
        """
        raise NotImplementedError

    def compute_objectives(self, predictions, batch, stage):
        """Compute loss, to be overridden by sub-classes.

        Arguments
        ---------
        predictions : torch.Tensor or Tensors
            The output tensor or tensors to evaluate.
            Comes directly from ``compute_forward()``.
        batch : torch.Tensor or tensors
            An element from the dataloader, including targets for comparison.
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

        Returns
        -------
        loss : torch.Tensor
            A tensor with the computed loss.
        """
        raise NotImplementedError

    def on_stage_start(self, stage, epoch=None):
        """Gets called when a stage starts.

        Useful for defining class variables used during the stage.

        Arguments
        ---------
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
        epoch : int
            The current epoch count.
        """
        pass

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of a stage.

        Useful for computing stage statistics, saving checkpoints, etc.

        Arguments
        ---------
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
        stage_loss : float
            The average loss over the completed stage.
        epoch : int
            The current epoch count.
        """
        pass

    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs,
    ):
        """Creates DataLoaders for Datasets.

        This is used by ``fit()`` and ``evaluate()`` if they just receive
        Datasets.

        Alternatively, this can be called from outside the Brain subclass.
        In that case, the DataLoader should be passed to ``fit()`` in place
        of the dataset.

        The Stage.TRAIN DataLoader is handled specially. It has extra args for
        shuffle and drop_last. In DDP a DistributedSampler is created (unless
        the dataset is an IterableDataset).

        NOTE
        ----
        Some important DataLoader arguments are passed via **loader_kwargs,
        e.g., batch_size, num_workers, pin_memory.

        NOTE
        ----
        By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test
        DataLoader being added to the checkpointer. If you need to add a
        recoverable after saving checkpoints (e.g., at test time, after
        checkpointing the training), and still be able to recover reasonably,
        you should probably specify ``allow_partial_load=True``.

        Arguments
        ---------
        dataset : Dataset
            A set of data to use to create data loader. If the Dataset is a
            DynamicItemDataset, PaddedBatch is used as the default collate_fn,
            unless specified in loader_kwargs.
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
        ckpt_prefix : str, None
            Prefix to use for SaveableDataLoader Checkpoint name. The Stage
            name is added to this to create the full key. Set to None to not
            save the DataLoader.
        **loader_kwargs : dict
            Additional keyword arguments to the DataLoader.
            E.g., batch_size, num_workers, pin_memory.
        """
        # TRAIN stage is handled specially.
        if stage == sb.Stage.TRAIN:
            loader_kwargs = self._train_loader_specifics(dataset, loader_kwargs)
        dataloader = sb.dataio.dataloader.make_dataloader(
            dataset, **loader_kwargs
        )

        if (
            self.checkpointer is not None
            and ckpt_prefix is not None
            and isinstance(dataloader, SaveableDataLoader)
        ):
            ckpt_key = ckpt_prefix + stage.name
            self.checkpointer.add_recoverable(ckpt_key, dataloader)
        return dataloader

    def _train_loader_specifics(self, dataset, loader_kwargs):
        sampler = loader_kwargs.get("sampler", None)
        # Shuffling should really only matter for the train stage. Shuffling
        # will also lead to more padding in batches if the order was otherwise
        # sorted by length.
        shuffle = loader_kwargs.get("shuffle", False)
        if shuffle and not self.distributed_launch:
            if sampler is not None:
                raise ValueError(
                    "Cannot specify both shuffle=True "
                    "and a sampler in loader_kwargs"
                )
            sampler = ReproducibleRandomSampler(dataset)
            self.train_sampler = sampler
            loader_kwargs["sampler"] = self.train_sampler
            # Delete the shuffle flag, since you cannot specify both a sampler and
            # shuffling:
            del loader_kwargs["shuffle"]

        # Possibly make a DistributedSampler or a wrapper for some other sampler
        if self.distributed_launch and not isinstance(dataset, IterableDataset):
            drop_last = loader_kwargs.get("drop_last", False)
            # num_replicas arg is equal to world_size
            # and retrieved automatically within
            # DistributedSampler obj.
            if sampler is not None:
                self.train_sampler = DistributedSamplerWrapper(
                    sampler,
                    rank=self.rank,
                    drop_last=drop_last,
                    shuffle=shuffle,
                )

                # with DistributedSamplerWrapper, one must disable shuffling for dataloader
                loader_kwargs["shuffle"] = False
            elif loader_kwargs.get("batch_sampler") is None:
                # Currently to get here, shuffle == False, so not passing it.
                # Otherwise we'd have to handle deleting it (but it is already
                # deleted).
                self.train_sampler = DistributedSampler(
                    dataset,
                    rank=self.rank,
                    shuffle=shuffle,
                    drop_last=drop_last,
                )

                # with DistributedSampler, one must disable shuffling for dataloader
                loader_kwargs["shuffle"] = False
            else:  # batch_sampler was specified
                # TODO: Could a DistributedSamplerWrapper actually work
                # just fine for wrapping a BatchSampler, as well?
                logger.warning(
                    "Cannot automatically solve distributed sampling "
                    "when using a BatchSampler."
                )
            loader_kwargs["sampler"] = self.train_sampler
        elif self.distributed_launch and isinstance(dataset, IterableDataset):
            logger.warning(
                "Cannot automatically solve distributed sampling "
                "for IterableDataset."
            )
        return loader_kwargs

    def on_fit_start(self):
        """Gets called at the beginning of ``fit()``, on multiple processes
        if ``distributed_count > 0`` and backend is ddp.

        Default implementation compiles the jit modules, initializes
        optimizers, and loads the latest checkpoint to resume training.
        """
        # Run this *after* starting all processes since jit modules cannot be
        # pickled.
        self._compile_jit()

        # Wrap modules with parallel backend after jit
        self._wrap_distributed()

        # Initialize optimizers after parameters are configured
        self.init_optimizers()

        # Load latest checkpoint to resume training if interrupted
        if self.checkpointer is not None:
            self.checkpointer.recover_if_possible(
                device=torch.device(self.device)
            )

    def init_optimizers(self):
        """Called during ``on_fit_start()``, initialize optimizers
        after parameters are fully configured (e.g. DDP, jit).

        The default implementation of this method depends on an optimizer
        class being passed at initialization that takes only a list
        of parameters (e.g., a lambda or a partial function definition).
        This creates a single optimizer that optimizes all trainable params.

        Override this class if there are multiple optimizers.
        """
        if self.opt_class is not None:
            self.optimizer = self.opt_class(self.modules.parameters())

            if self.checkpointer is not None:
                self.checkpointer.add_recoverable("optimizer", self.optimizer)

    def on_evaluate_start(self, max_key=None, min_key=None):
        """Gets called at the beginning of ``evaluate()``

        Default implementation loads the best-performing checkpoint for
        evaluation, based on stored metrics.

        Arguments
        ---------
        max_key : str
            Key to use for finding best checkpoint (higher is better).
            By default, passed to ``self.checkpointer.recover_if_possible()``.
        min_key : str
            Key to use for finding best checkpoint (lower is better).
            By default, passed to ``self.checkpointer.recover_if_possible()``.
        """

        # Recover best checkpoint for evaluation
        if self.checkpointer is not None:
            self.checkpointer.recover_if_possible(
                max_key=max_key,
                min_key=min_key,
                device=torch.device(self.device),
            )

    def fit_batch(self, batch):
        """Fit one batch, override to do multiple updates.

        The default implementation depends on a few methods being defined
        with a particular behavior:

        * ``compute_forward()``
        * ``compute_objectives()``

        Also depends on having optimizers passed at initialization.

        Arguments
        ---------
        batch : list of torch.Tensors
            Batch of data to use for training. Default implementation assumes
            this batch has two elements: inputs and targets.

        Returns
        -------
        detached loss
        """
        # Managing automatic mixed precision
        if self.auto_mix_prec:
            self.optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = self.compute_forward(batch, Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            if self.check_gradients(loss):
                self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            outputs = self.compute_forward(batch, Stage.TRAIN)
            loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
            loss.backward()
            if self.check_gradients(loss):
                self.optimizer.step()
            self.optimizer.zero_grad()

        return loss.detach().cpu()

    def check_gradients(self, loss):
        """Check if gradients are finite and not too large.

        Automatically clips large gradients.

        Arguments
        ---------
        loss : tensor
            The loss tensor after ``backward()`` has been called but
            before the optimizers ``step()``.

        Returns
        -------
        bool
            Whether or not the optimizer step should be carried out.
        """
        if not torch.isfinite(loss):
            self.nonfinite_count += 1

            # Print helpful debug info
            logger.warn(f"Loss is {loss}.")
            for p in self.modules.parameters():
                if not torch.isfinite(p).all():
                    logger.warn("Parameter is not finite: " + str(p))

            # Check if patience is exhausted
            if self.nonfinite_count > self.nonfinite_patience:
                raise ValueError(
                    "Loss is not finite and patience is exhausted. "
                    "To debug, wrap `fit()` with "
                    "autograd's `detect_anomaly()`, e.g.\n\nwith "
                    "torch.autograd.detect_anomaly():\n\tbrain.fit(...)"
                )
            else:
                logger.warn("Patience not yet exhausted, ignoring this batch.")
                return False

        # Clip gradient norm
        torch.nn.utils.clip_grad_norm_(
            (p for p in self.modules.parameters()), self.max_grad_norm
        )

        return True

    def evaluate_batch(self, batch, stage):
        """Evaluate one batch, override for different procedure than train.

        The default implementation depends on two methods being defined
        with a particular behavior:

        * ``compute_forward()``
        * ``compute_objectives()``

        Arguments
        ---------
        batch : list of torch.Tensors
            Batch of data to use for evaluation. Default implementation assumes
            this batch has two elements: inputs and targets.
        stage : Stage
            The stage of the experiment: Stage.VALID, Stage.TEST

        Returns
        -------
        detached loss
        """

        out = self.compute_forward(batch, stage=stage)
        loss = self.compute_objectives(out, batch, stage=stage)
        return loss.detach().cpu()

    def fit(
        self,
        epoch_counter,
        train_set,
        valid_set=None,
        progressbar=None,
        train_loader_kwargs={},
        valid_loader_kwargs={},
    ):
        """Iterate epochs and datasets to improve objective.

        Relies on the existence of multiple functions that can (or should) be
        overridden. The following methods are used and expected to have a
        certain behavior:

        * ``fit_batch()``
        * ``evaluate_batch()``
        * ``update_average()``

        If the initialization was done with distributed_count > 0 and the
        distributed_backend is ddp, this will generally handle multiprocess
        logic, like splitting the training data into subsets for each device and
        only saving a checkpoint on the main process.

        Arguments
        ---------
        epoch_counter : iterable
            Each call should return an integer indicating the epoch count.
        train_set : Dataset, DataLoader
            A set of data to use for training. If a Dataset is given, a
            DataLoader is automatically created. If a DataLoader is given, it is
            used directly.
        valid_set : Dataset, DataLoader
            A set of data to use for validation. If a Dataset is given, a
            DataLoader is automatically created. If a DataLoader is given, it is
            used directly.
        train_loader_kwargs : dict
            Kwargs passed to `make_dataloader()` for making the train_loader
            (if train_set is a Dataset, not DataLoader).
            E.g., batch_size, num_workers.
            DataLoader kwargs are all valid.
        valid_loader_kwargs : dict
            Kwargs passed to `make_dataloader()` for making the valid_loader
            (if valid_set is a Dataset, not DataLoader).
            E.g., batch_size, num_workers.
            DataLoader kwargs are all valid.
        progressbar : bool
            Whether to display the progress of each epoch in a progressbar.
        """

        if not isinstance(train_set, DataLoader):
            train_set = self.make_dataloader(
                train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
            )
        if valid_set is not None and not isinstance(valid_set, DataLoader):
            valid_set = self.make_dataloader(
                valid_set,
                stage=sb.Stage.VALID,
                ckpt_prefix=None,
                **valid_loader_kwargs,
            )

        self.on_fit_start()

        if progressbar is None:
            progressbar = not self.noprogressbar

        # Iterate epochs
        for epoch in epoch_counter:

            # Training stage
            self.on_stage_start(Stage.TRAIN, epoch)
            self.modules.train()

            # Reset nonfinite count to 0 each epoch
            self.nonfinite_count = 0

            if self.train_sampler is not None and hasattr(
                self.train_sampler, "set_epoch"
            ):
                self.train_sampler.set_epoch(epoch)

            # Time since last intra-epoch checkpoint
            last_ckpt_time = time.time()

            # Only show progressbar if requested and main_process
            enable = progressbar and sb.utils.distributed.if_main_process()
            with tqdm(
                train_set,
                initial=self.step,
                dynamic_ncols=True,
                disable=not enable,
            ) as t:
                for batch in t:
                    self.step += 1
                    loss = self.fit_batch(batch)
                    self.avg_train_loss = self.update_average(
                        loss, self.avg_train_loss
                    )
                    t.set_postfix(train_loss=self.avg_train_loss)

                    # Debug mode only runs a few batches
                    if self.debug and self.step == self.debug_batches:
                        break

                    if (
                        self.checkpointer is not None
                        and self.ckpt_interval_minutes > 0
                        and time.time() - last_ckpt_time
                        >= self.ckpt_interval_minutes * 60.0
                    ):
                        run_on_main(self._save_intra_epoch_ckpt)
                        last_ckpt_time = time.time()

            # Run train "on_stage_end" on all processes
            self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
            self.avg_train_loss = 0.0
            self.step = 0

            # Validation stage
            if valid_set is not None:
                self.on_stage_start(Stage.VALID, epoch)
                self.modules.eval()
                avg_valid_loss = 0.0
                with torch.no_grad():
                    for batch in tqdm(
                        valid_set, dynamic_ncols=True, disable=not enable
                    ):
                        self.step += 1
                        loss = self.evaluate_batch(batch, stage=Stage.VALID)
                        avg_valid_loss = self.update_average(
                            loss, avg_valid_loss
                        )

                        # Debug mode only runs a few batches
                        if self.debug and self.step == self.debug_batches:
                            break

                    # Only run validation "on_stage_end" on main process
                    self.step = 0
                    run_on_main(
                        self.on_stage_end,
                        args=[Stage.VALID, avg_valid_loss, epoch],
                    )

            # Debug mode only runs a few epochs
            if self.debug and epoch == self.debug_epochs:
                break

    def _save_intra_epoch_ckpt(self):
        """Saves a CKPT with specific intra-epoch flag."""
        self.checkpointer.save_and_keep_only(
            end_of_epoch=False,
            num_to_keep=1,
            ckpt_predicate=lambda c: INTRA_EPOCH_CKPT_FLAG in c.meta,
            meta={INTRA_EPOCH_CKPT_FLAG: True},
            verbosity=logging.DEBUG,
        )

    def _compile_jit(self):
        """Compile requested modules with ``torch.jit.script``."""
        if self.jit_module_keys is None:
            return

        for name in self.jit_module_keys:
            if name not in self.modules:
                raise ValueError(
                    "module" + name + " is not defined in your hparams file."
                )
            module = torch.jit.script(self.modules[name])
            self.modules[name] = module.to(self.device)

    def _wrap_distributed(self):
        """Wrap modules with distributed wrapper when requested."""
        if not self.distributed_launch and not self.data_parallel_backend:
            return
        elif self.distributed_launch:
            for name, module in self.modules.items():
                if any(p.requires_grad for p in module.parameters()):
                    # for DDP, all modules must run on the same GPU
                    module = SyncBatchNorm.convert_sync_batchnorm(module)
                    module = DDP(module, device_ids=[self.device])
                    self.modules[name] = module
        else:
            # data_parallel_backend
            for name, module in self.modules.items():
                if any(p.requires_grad for p in module.parameters()):
                    # if data_parallel_count == -1, use all GPUs;
                    # otherwise, use the specified subset of GPUs
                    if self.data_parallel_count == -1:
                        module = DP(module)
                    else:
                        module = DP(
                            module,
                            [i for i in range(self.data_parallel_count)],
                        )
                    self.modules[name] = module

    def evaluate(
        self,
        test_set,
        max_key=None,
        min_key=None,
        progressbar=None,
        test_loader_kwargs={},
    ):
        """Iterate test_set and evaluate brain performance. By default, loads
        the best-performing checkpoint (as recorded using the checkpointer).

        Arguments
        ---------
        test_set : Dataset, DataLoader
            If a DataLoader is given, it is iterated directly. Otherwise passed
            to ``self.make_dataloader()``.
        max_key : str
            Key to use for finding best checkpoint, passed to
            ``on_evaluate_start()``.
        min_key : str
            Key to use for finding best checkpoint, passed to
            ``on_evaluate_start()``.
        progressbar : bool
            Whether to display the progress in a progressbar.
        test_loader_kwargs : dict
            Kwargs passed to ``make_dataloader()`` if ``test_set`` is not a
            DataLoader. NOTE: ``loader_kwargs["ckpt_prefix"]`` gets
            automatically overwritten to ``None`` (so that the test DataLoader
            is not added to the checkpointer).

        Returns
        -------
        average test loss
        """
        if progressbar is None:
            progressbar = not self.noprogressbar

        if not isinstance(test_set, DataLoader):
            test_loader_kwargs["ckpt_prefix"] = None
            test_set = self.make_dataloader(
                test_set, Stage.TEST, **test_loader_kwargs
            )
        self.on_evaluate_start(max_key=max_key, min_key=min_key)
        self.on_stage_start(Stage.TEST, epoch=None)
        self.modules.eval()
        avg_test_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(
                test_set, dynamic_ncols=True, disable=not progressbar
            ):
                self.step += 1
                loss = self.evaluate_batch(batch, stage=Stage.TEST)
                avg_test_loss = self.update_average(loss, avg_test_loss)

                # Debug mode only runs a few batches
                if self.debug and self.step == self.debug_batches:
                    break

            # Only run evaluation "on_stage_end" on main process
            run_on_main(
                self.on_stage_end, args=[Stage.TEST, avg_test_loss, None]
            )
        self.step = 0
        return avg_test_loss

    def update_average(self, loss, avg_loss):
        """Update running average of the loss.

        Arguments
        ---------
        loss : torch.Tensor
            Detached loss, a single scalar value.
        avg_loss : float
            Current running average.

        Returns
        -------
        avg_loss : float
            The updated running average.
        """
        if torch.isfinite(loss):
            avg_loss -= avg_loss / (self.step + 1)
            avg_loss += float(loss) / (self.step + 1)
        return avg_loss

    @sb.utils.checkpoints.mark_as_saver
    def _save(self, path):
        save_dict = {
            "step": self.step,
            "avg_train_loss": self.avg_train_loss,
        }
        with open(path, "w") as w:
            w.write(yaml.dump(save_dict))

    @sb.utils.checkpoints.mark_as_loader
    def _recover(self, path, end_of_epoch, device):
        del end_of_epoch
        del device
        with open(path) as f:
            save_dict = yaml.safe_load(f)
        self.step = save_dict["step"]
        self.avg_train_loss = save_dict["avg_train_loss"]
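
The ``update_average`` method above maintains an incremental mean of the per-batch losses. Below is a minimal standalone sketch of the same update rule; the function and variable names are illustrative, and the class version divides by ``self.step + 1`` rather than an explicit ``step`` argument.

def running_average(loss, avg_loss, step):
    # Incremental mean: after `step` finite losses, avg_loss equals their
    # arithmetic mean. Algebraically: avg += (loss - avg) / step.
    avg_loss -= avg_loss / step
    avg_loss += loss / step
    return avg_loss

avg = 0.0
for step, loss in enumerate([2.0, 4.0, 6.0], start=1):
    avg = running_average(loss, avg, step)
print(avg)  # 4.0, the mean of the three losses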
Example #26
0
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device,
                                              args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval,
                                 output_dir / "eval.pth")
        return

    #cab
    writer = SummaryWriter("runs/" + args.tb_name)

    best_value = 0

    print("Start training, best_value is " + str(best_value))
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()

        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device,
                                              args.output_dir)

        #cab
        for k, v in train_stats.items():
            if isinstance(v, float):
                writer.add_scalar(f'train_{k}', v, epoch)

        new_value = 0
        # COCO stat labels, in the order returned by the evaluator's stats array.
        coco_stat_labels = [
            'Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]',
            'Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ]',
            'Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ]',
            'Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
            'Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
            'Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
            'Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ]',
            'Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ]',
            'Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]',
            'Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ]',
            'Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]',
            'Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ]',
        ]
        for k, v in test_stats.items():
            if isinstance(v, float):
                writer.add_scalar(f'test_{k}', v, epoch)
            if k == "coco_eval_bbox":
                new_value = v[0]
                for label, stat in zip(coco_stat_labels, v):
                    writer.add_scalar('Bbox ' + label, stat, epoch)
            if k == "coco_eval_masks":
                new_value = v[0]
                for label, stat in zip(coco_stat_labels, v):
                    writer.add_scalar('Mask ' + label, stat, epoch)

        print("Epoch finished, best_value is " + str(best_value))

        save_pth = False
        if best_value < new_value:
            best_value = new_value
            save_pth = True

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')

            if save_pth:
                checkpoint_paths.append(output_dir / 'best.pth')
                with open(output_dir / 'best_log.txt', 'w+') as best_log:
                    best_log.write(f'Saved model at epoch {epoch:04}\n')

            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        #/cab

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
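
Checkpoint writes in the example above go through ``utils.save_on_master`` so that only one process writes each file under DDP. The following is a minimal sketch of that rank-0-only pattern, assuming ``torch.distributed`` is the active process group; the helper name ``save_on_rank0`` is illustrative, not the DETR utility itself.

import torch
import torch.distributed as dist


def save_on_rank0(obj, path):
    # Write only from the main process; other ranks skip the save so they do
    # not race to write the same file on a shared filesystem.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(obj, path)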
Example #27
0
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    print(args)

    device = torch.device(args.device)

    # Fix the seed for reproducibility.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    dataset_train = build_dataset(image_set='train', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)
    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)

    # Load from pretrained DETR model.
    assert args.num_queries == 100, args.num_queries
    assert args.enc_layers == 6 and args.dec_layers == 6
    assert args.backbone in ['resnet50', 'resnet101', 'swin'], args.backbone
    if args.backbone == 'resnet50':
        pretrain_model = './data/detr_coco/detr-r50-e632da11.pth'
    elif args.backbone == 'resnet101':
        pretrain_model = './data/detr_coco/detr-r101-2c7b67e5.pth'
    else:
        pretrain_model = None
    if pretrain_model is not None:
        pretrain_dict = torch.load(pretrain_model, map_location='cpu')['model']
        my_model_dict = model_without_ddp.state_dict()
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in my_model_dict}
        my_model_dict.update(pretrain_dict)
        model_without_ddp.load_state_dict(my_model_dict)

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch,
            args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoints: before the LR drop, every 100 epochs, and
            # every 10 epochs after the LR drop
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            if (epoch + 1) > args.lr_drop and (epoch + 1) % 10 == 0:
                checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
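
The pretrained-DETR warm start above keeps only the checkpoint entries whose keys already exist in the current model. Here is a minimal sketch of that filtering pattern; the shape check is an extra safeguard not present in the example, and the function name is illustrative.

import torch


def load_matching_weights(model, ckpt_path):
    pretrained = torch.load(ckpt_path, map_location='cpu')['model']
    own_state = model.state_dict()
    # Keep only entries the target model knows about (and whose shapes match).
    matched = {k: v for k, v in pretrained.items()
               if k in own_state and v.shape == own_state[k].shape}
    own_state.update(matched)
    model.load_state_dict(own_state)
    return sorted(matched)  # keys that were actually transferred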
Example #28
0
def main(args):
    prt.init_distributed_mode(args)
    device = torch.device(args.device)

    model = SlotModel(args)
    print(
        "train model: " +
        f"{'use slot ' if args.use_slot else 'without slot '}" +
        f"{'negetive loss' if args.use_slot and args.loss_status != 1 else 'positive loss'}"
    )
    model.to(device)
    model_without_ddp = model

    if args.thop:

        def freeze_layers(model):
            for layer in model.children():
                if isinstance(layer, torch.nn.Sequential):
                    for sub_layer in layer:
                        sub_layer.requires_grad = False
                        for parameter in sub_layer.parameters():
                            parameter.requires_grad = False
                else:
                    layer.requires_grad = False
                    for parameter in layer.parameters():
                        parameter.requires_grad = False

        def unfreeze_layers(model):
            for layer in model.children():
                if isinstance(layer, torch.nn.Sequential):
                    for sub_layer in layer:
                        sub_layer.requires_grad = True
                        for parameter in sub_layer.parameters():
                            parameter.requires_grad = True
                else:
                    layer.requires_grad = True
                    for parameter in layer.parameters():
                        parameter.requires_grad = True

        unfreeze_layers(model)
        n_parameters = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
        print(float(n_parameters) / 1000000, 'M')

        freeze_layers(model)
        model.cpu()
        model.eval()
        tl.set_backend('pytorch')

        input_ = torch.randn(1, 3, 260, 260)

        flops_list = []
        params_list = []
        acc_list = []

        flops, params = profile(model, inputs=(input_, ))
        flops_list.append(flops)
        params_list.append(params)
        flops, params = clever_format([flops, params], "%.3f")
        print(float(n_parameters) / 1000000, 'M', params, flops)
        return [float(n_parameters) / 1000000, flops_list[-1] / 1000000000]

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    params = [p for p in model_without_ddp.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_drop)

    dataset_train, dataset_val = select_dataset(args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)
    data_loader_train = DataLoaderX(dataset_train,
                                    batch_sampler=batch_sampler_train,
                                    num_workers=args.num_workers)
    data_loader_val = DataLoaderX(dataset_val,
                                  args.batch_size,
                                  sampler=sampler_val,
                                  num_workers=args.num_workers)
    output_dir = Path(args.output_dir)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    print("Start training")
    start_time = time.time()
    log = MetricLog()
    record = log.record
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_one_epoch(model, data_loader_train, optimizer, device, record,
                        epoch)
        lr_scheduler.step()
        if args.output_dir:
            ckpt_prefix = (
                f"{args.dataset}_"
                + f"{'use_slot_' if args.use_slot else 'no_slot_'}"
                + f"{'negative_' if args.use_slot and args.loss_status != 1 else ''}"
                + (f"for_area_size_{args.lambda_value}_{args.slots_per_class}_"
                   if args.cal_area_size else '')
            )
            checkpoint_paths = [output_dir / (ckpt_prefix + 'checkpoint.pth')]
            # extra checkpoint before LR drop and every 10 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0:
                checkpoint_paths.append(
                    output_dir / (ckpt_prefix + f'checkpoint{epoch:04}.pth'))
            for checkpoint_path in checkpoint_paths:
                prt.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        evaluate(model, data_loader_val, device, record, epoch)
        log.print_metric()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    return [record["train"]["acc"][-1], record["val"]["acc"][-1]]
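
The ``freeze_layers``/``unfreeze_layers`` helpers above walk the module tree, but the same effect can be obtained by toggling ``requires_grad`` over ``model.parameters()`` directly. A minimal sketch follows; the helper names are illustrative.

from torch import nn


def set_requires_grad(model: nn.Module, flag: bool) -> None:
    # Freeze (flag=False) or unfreeze (flag=True) every parameter of the model.
    for p in model.parameters():
        p.requires_grad = flag


def count_trainable_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    m = nn.Linear(10, 2)
    set_requires_grad(m, False)
    print(count_trainable_params(m))  # 0
    set_requires_grad(m, True)
    print(count_trainable_params(m))  # 22 = 10 * 2 weights + 2 biases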
Example #29
0
def main(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   args.lr_drop,
                                                   gamma=0.9)

    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn,
                                   num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers)

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    output_dir = output_dir / f"{args.backbone}_{args.transformer_type}"
    if args.output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device,
                                              args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval,
                                 output_dir / "eval.pth")
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch,
                                      args.clip_max_norm)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / f'checkpoint_{epoch}.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch}_extra.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device,
                                              args.output_dir)

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / 'eval').mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ['latest.pth']
                    if epoch % 50 == 0:
                        filenames.append(f'{epoch:03}.pth')
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                   output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
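
The train-loader wiring used throughout these examples (DistributedSampler under DDP, RandomSampler otherwise, wrapped in a BatchSampler) can be factored into a single helper. A minimal sketch, assuming the process group is initialized elsewhere; the function name is illustrative.

from torch.utils.data import (BatchSampler, DataLoader, DistributedSampler,
                              RandomSampler)


def build_train_loader(dataset, batch_size, distributed, num_workers=0,
                       collate_fn=None):
    # Each DDP rank sees a distinct shard via DistributedSampler; single-process
    # training falls back to a plain RandomSampler.
    sampler = DistributedSampler(dataset) if distributed else RandomSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    loader = DataLoader(dataset, batch_sampler=batch_sampler,
                        num_workers=num_workers, collate_fn=collate_fn)
    # When distributed, call sampler.set_epoch(epoch) at the start of every
    # epoch; otherwise each epoch reuses the same shuffling order.
    return loader, sampler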
Example #30
0
def main(rank):
    global args, best_prec1
    args = parser.parse_args()

    device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
    print("====rank={}    device_id={} ".format(rank, device_id))

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div,
                                               args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join([
        'TSM', args.dataset, args.modality, full_arch_name,
        args.consensus_type,
        'segment%d' % args.num_segments, 'e{}'.format(args.epochs)
    ])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    torch.cuda.set_device(device_id)

    check_rootfolders()

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift,
                shift_div=args.shift_div,
                shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False
        if 'something' in args.dataset or 'jester' in args.dataset else True)

    #model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[device_id],
                                                      output_device=device_id)
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        replace_dict = []
        for k, v in sd.items():
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k, v in model_dict.items():
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(list(sd.keys()))
        keys2 = set(list(model_dict.keys()))
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_dataset = TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dense_sample=args.dense_sample)

    train_sampler = DistributedSampler(train_dataset, shuffle=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)  # drop the last incomplete batch so every rank gets the same number of batches

    if rank == 0:
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ]),
            dense_sample=args.dense_sample),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
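        # NOTE: the validation loader is built on rank 0 only; the per-epoch
        # validation further below is likewise guarded by rank == 0.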

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
        train_sampler.set_epoch(epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training,
              tf_writer, scaler, args.batch_size, args.amp, rank)
        #train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)
        # evaluate on validation set

        if rank == 0 and ((epoch + 1) % args.eval_freq == 0
                          or epoch == args.epochs - 1):
            prec1 = validate(val_loader, model, criterion, epoch, log_training,
                             tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)