Example #1
    def __call__(self, model, optimizer, criterion, metrics, scheduler, options):
        """Train models and perform validation.

        :param model: a pytorch model to be trained and validated.
        :type model: nn.Module
        :param optimizer: an optimizer for the given model.
        :param criterion: the loss function.
        :param metrics: metrics like TopKAccuracy.
        :param scheduler: a scheduler for hyperparameters.
        :param options: a global object containing all of the options.
        :type options: argparse.Namespace
        """
        # define some parameters for training.
        log.info('There are {} epochs, {} mini-batches per epoch (batch size:{}).'
                 .format(options.train_epochs, options.train_num_batches,
                         options.batch_size), 0)

        # train the model and evaluate the model per args.eval_freq
        max_epochs = min(options.train_epochs, options.max_train_steps)\
            if options.max_train_steps else options.train_epochs
        start_epoch = options.runtime['current_epoch'] if options.resume else 0
        options.runtime['records'] = options.runtime.get('records', [])
        options.runtime['cumu_time_val'] = options.runtime.get('cumu_time_val', [])
        options.runtime['cumu_time_train'] = options.runtime.get('cumu_time_train', [])

        dist.barrier()

        timeit = Timeit(0 if len(options.runtime['cumu_time_val']) == 0
                        else options.runtime['cumu_time_val'][-1])
        for epoch in range(start_epoch, max_epochs):
            options.runtime['current_epoch'] = epoch

            # schedule learning rates
            if options.lr_scheduler_level == 'epoch':
                scheduler.step()

            # Per epoch information.
            log.info("Current epoch : {} : lr={} : time={:10.3e}"
                     .format(epoch, scheduler.get_lr(), timeit.cumu), 0)

            train_epoch(model, optimizer, criterion, scheduler, options, timeit)

            if options.validation:
                timeit.pause()
                do_validate(model, optimizer, criterion, metrics, scheduler, options, timeit)
                timeit.resume()

            if options.repartition_per_epoch:
                options = create_dataset(options, train=True)
                options = create_dataset(options, train=False)
Example #2
    def _test_barrier_helper(self, group, group_id, rank):
        WAIT_TIME = 0.3  # seconds

        for dest in group:
            expected_time = torch.DoubleTensor(1).fill_(0.0)
            if dest == rank:
                expected_time.fill_(time.time() + WAIT_TIME)
                dist.broadcast(expected_time, dest, group_id)
                time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
                dist.barrier(group_id)
            else:
                dist.broadcast(expected_time, dest, group_id)
                dist.barrier(group_id)
                self.assertGreaterEqual(time.time(), expected_time[0])

        self._barrier()
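
The helper above checks that ranks blocked in dist.barrier() really wait for the slow broadcaster. Below is a minimal, self-contained sketch of the same idea, runnable outside the test harness; the gloo backend and the mp.spawn launch are assumptions for illustration, not taken from the examples.

import os
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # env:// initialization needs a master address/port; values here are illustrative
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        time.sleep(1.0)      # rank 0 is deliberately slow
    dist.barrier()           # every rank waits here until rank 0 arrives
    print(f"rank {rank} passed the barrier")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
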
Example #3
    def initialize(self, training=True, force_load_plans=False):
        """
        :param training:
        :return:
        """
        if not self.was_initialized:
            os.makedirs(self.output_folder, exist_ok=True)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    distributed.barrier()
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2**i) for i in range(net_numpool)])

                # we don't use the lowest-resolution output. Normalize the remaining weights so that they sum to 1
                mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights

                # np.random.random_integers is deprecated; randint's upper bound is exclusive
                seeds_train = np.random.randint(
                    0, 100000, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.randint(
                    0, 100000,
                    max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self.network = DDP(self.network, device_ids=[self.local_rank])

        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True
Example #4
                    default=2, type=int,
                    help='set the inclusive lower limit for the number of ' +
                    'tensors to be sent during one test run; ' +
                    'default: 2 (10**2 = 100)')

args = parser.parse_args()

MIN_NUM_TENSORS = args.min_num_tensors
MIN_BYTES = args.min_bytes
MAX_NUM_TENSORS = args.max_num_tensors + 1
MAX_BYTES = args.max_bytes + 1

dist.init_process_group(backend=os.environ['BACKEND'])

rank = dist.get_rank()
dist.barrier()

if rank == 0:
    print_header("broadcast")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.broadcast(tensor, 0)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
    print()
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes)
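
The listing is truncated inside the non-root branch. Every rank has to issue the matching collective call, otherwise rank 0 blocks forever; a hedged guess at how the receiving side presumably continues (a continuation sketch, not part of the original source):

        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.broadcast(tensor, 0)  # receive the broadcast from rank 0; no timing on this side
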
Example #5
    def fid(self, fid_num, z=None, ignore_cache=False, align_tf=True):
        """Computes the FID metric."""
        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')
        fid_num = min(fid_num, len(self.val_loader.dataset))

        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info('Finished building inception model.')

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not fid_num:
            return -1

        indices = list(range(self.rank, fid_num, self.world_size))

        self.logger.init_pbar()

        # Extract features from fake images.
        fake_feature_list = []
        task1 = self.logger.add_pbar_task('Fake', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            if z is None:
                code = torch.randn(batch_size, self.z_space_dim).cuda()
            else:
                code = z[sub_indices].cuda()
            with torch.no_grad():
                if 'generator_smooth' in self.models:
                    G = self.models['generator_smooth']
                else:
                    G = self.models['generator']
                fake_images = G(code)['image']
                fake_feature_list.append(
                    extract_feature(self.inception_model, fake_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                np.concatenate(fake_feature_list, axis=0))

        # Extract features from real images if needed.
        cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        if do_real_test:
            real_feature_list = []
            task2 = self.logger.add_pbar_task("Real", total=fid_num)
            for batch_idx in range(0, len(indices), self.val_batch_size):
                sub_indices = indices[batch_idx:batch_idx +
                                      self.val_batch_size]
                batch_size = len(sub_indices)
                data = next(self.val_loader)
                for key in data:
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
                with torch.no_grad():
                    real_images = data['image']
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
                self.logger.update_pbar(task2, batch_size * self.world_size)
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    np.concatenate(real_feature_list, axis=0))

        dist.barrier()
        if self.rank != 0:
            return -1
        self.logger.close_pbar()

        # Collect fake features.
        fake_feature_list.clear()
        for rank in range(self.world_size):
            fake_feature_list.append(
                np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
            os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
        fake_features = np.concatenate(fake_feature_list, axis=0)
        assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
        feature_dim = fake_features.shape[1]
        pad = fid_num % self.world_size
        if pad:
            pad = self.world_size - pad
        fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
        fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
        fake_features = fake_features.transpose(1, 0, 2)
        fake_features = fake_features.reshape(-1, feature_dim)[:fid_num]

        # Collect (or load) real features.
        if do_real_test:
            real_feature_list.clear()
            for rank in range(self.world_size):
                real_feature_list.append(
                    np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
                os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
            real_features = np.concatenate(real_feature_list, axis=0)
            assert real_features.shape == (fid_num, feature_dim)
            real_features = np.pad(real_features, ((0, pad), (0, 0)))
            real_features = real_features.reshape(self.world_size, -1,
                                                  feature_dim)
            real_features = real_features.transpose(1, 0, 2)
            real_features = real_features.reshape(-1, feature_dim)[:fid_num]
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
            assert real_features.shape == (fid_num, feature_dim)

        fid_value = compute_fid(fake_features, real_features)
        return fid_value
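
Rank 0 above reassembles features that were sharded round-robin across ranks using a pad / reshape / transpose / reshape sequence. A small NumPy sketch of just that reordering, assuming fid_num is a multiple of world_size so that every rank holds the same number of rows (all names here are illustrative):

import numpy as np

world_size, fid_num, feature_dim = 4, 12, 3
features = np.arange(fid_num * feature_dim, dtype=float).reshape(fid_num, feature_dim)

# each rank holds the round-robin slice rank, rank + world_size, rank + 2 * world_size, ...
per_rank = [features[r::world_size] for r in range(world_size)]
gathered = np.concatenate(per_rank, axis=0)        # rank-major order, not sample order

restored = (gathered.reshape(world_size, -1, feature_dim)
            .transpose(1, 0, 2)
            .reshape(-1, feature_dim)[:fid_num])
assert np.array_equal(restored, features)          # original sample order is recovered
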
Example #6
def main(args):
    try:
        os.makedirs(args.checkpoint_path)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise

    print('loading dataset')
    train_loader, train_sampler, valid_loader = get_dataset(args)

    print('building model')
    model = get_model(args)

    if args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               args.learning_rate,
                               betas=(args.alpha, args.beta),
                               eps=args.epsilon)
    elif args.optim == 'sgd':  # original implementation in the paper
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              args.learning_rate,
                              weight_decay=1e-4,
                              momentum=args.alpha,
                              nesterov=True)
    else:
        assert False, "only support adam or sgd"

    # learning rate decay
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=args.reduce_factor,
                                               patience=args.patience_epoch,
                                               verbose=True)

    best_loss = float('inf')

    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(env='weakly-supervised')
        vis_window = {'iter': None, 'loss': None}

    all_cls_losses = []
    all_training_losses = []

    for train_epoch in range(args.max_epochs):
        t_epoch_start = time.time()
        print('Epoch: {}'.format(train_epoch))

        if args.distributed:
            train_sampler.set_epoch(train_epoch)

        epoch_loss = train(train_epoch,
                           model,
                           optimizer,
                           train_loader,
                           args,
                           vis=None,
                           vis_window=None)
        all_training_losses.append(epoch_loss)

        val_cls_loss = valid(model, valid_loader)
        all_cls_losses.append(val_cls_loss)

        # learning rate decay
        scheduler.step(val_cls_loss)

        if args.enable_visdom:
            if vis_window['loss'] is None:
                if not args.distributed or (args.distributed
                                            and dist.get_rank() == 0):
                    vis_window['loss'] = vis.line(
                        X=np.tile(np.arange(len(all_cls_losses)), (2, 1)).T,
                        Y=np.column_stack((np.asarray(all_training_losses),
                                           np.asarray(all_cls_losses))),
                        opts=dict(title='Loss',
                                  xlabel='Validation Iter',
                                  ylabel='Loss',
                                  legend=['train', 'dev_cls']))
            else:
                if not args.distributed or (args.distributed
                                            and dist.get_rank() == 0):
                    vis.line(X=np.tile(np.arange(len(all_cls_losses)),
                                       (2, 1)).T,
                             Y=np.column_stack(
                                 (np.asarray(all_training_losses),
                                  np.asarray(all_cls_losses))),
                             win=vis_window['loss'],
                             opts=dict(title='Loss',
                                       xlabel='Validation Iter',
                                       ylabel='Loss',
                                       legend=['train', 'dev_cls']))

        if val_cls_loss < best_loss:
            best_loss = val_cls_loss
            if (args.distributed
                    and dist.get_rank() == 0) or not args.distributed:
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.checkpoint_path, 'model_best_loss.t7'))
            print('*' * 5)
            print('Better validation loss {:.4f} found, save model'.format(
                val_cls_loss))

        # save eval and train losses
        if (args.distributed and dist.get_rank() == 0) or not args.distributed:
            torch.save(
                {
                    'train_loss': all_training_losses,
                    'eval_cls_loss': all_cls_losses,
                }, os.path.join(args.checkpoint_path, 'model_losses.t7'))

        # validation/save checkpoint every few epochs
        if train_epoch % args.save_checkpoint_every == 0 or train_epoch == args.max_epochs - 1:
            if (args.distributed
                    and dist.get_rank() == 0) or not args.distributed:
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.checkpoint_path,
                                 'model_epoch_{}.t7'.format(train_epoch)))

        # all other process wait for the 1st process to finish
        if args.distributed:
            dist.barrier()

        print('-' * 80)
        print('Epoch {} summary'.format(train_epoch))
        print('Train loss: {:.4f}, val loss: {:.4f}, Time: {:.4f}s'.format(
            epoch_loss, val_cls_loss,
            time.time() - t_epoch_start))
        print('-' * 80)
Example #7
def run(rank, size, inputs, adj_matrix, data, features, classes, device):
    global epochs
    global mid_layer
    global run
    global timing

    best_val_acc = test_acc = 0
    outputs = None
    group = dist.new_group(list(range(size)))

    if rank >= size:
        return

    # adj_matrix_loc = torch.rand(node_count, n_per_proc)
    # inputs_loc = torch.rand(n_per_proc, inputs.size(1))

    inputs_loc, adj_matrix_loc, am_pbyp = oned_partition(
        rank, size, inputs, adj_matrix, data, features, classes, device)

    inputs_loc = inputs_loc.to(device)
    adj_matrix_loc = adj_matrix_loc.to(device)
    for i in range(len(am_pbyp)):
        am_pbyp[i] = am_pbyp[i].t().coalesce().to(device)

    for i in range(run_count):
        run = i
        torch.manual_seed(0)
        weight1_nonleaf = torch.rand(features, mid_layer, requires_grad=True)
        weight1_nonleaf = weight1_nonleaf.to(device)
        weight1_nonleaf.retain_grad()

        weight2_nonleaf = torch.rand(mid_layer, classes, requires_grad=True)
        weight2_nonleaf = weight2_nonleaf.to(device)
        weight2_nonleaf.retain_grad()

        weight1 = Parameter(weight1_nonleaf)
        weight2 = Parameter(weight2_nonleaf)

        optimizer = torch.optim.Adam([weight1, weight2], lr=0.01)
        dist.barrier(group)

        tstart = 0.0
        tstop = 0.0

        total_time[i] = dict()
        comm_time[i] = dict()
        comp_time[i] = dict()
        scomp_time[i] = dict()
        dcomp_time[i] = dict()
        bcast_comm_time[i] = dict()
        barrier_time[i] = dict()
        barrier_subset_time[i] = dict()
        op1_comm_time[i] = dict()
        op2_comm_time[i] = dict()

        total_time[i][rank] = 0.0
        comm_time[i][rank] = 0.0
        comp_time[i][rank] = 0.0
        scomp_time[i][rank] = 0.0
        dcomp_time[i][rank] = 0.0
        bcast_comm_time[i][rank] = 0.0
        barrier_time[i][rank] = 0.0
        barrier_subset_time[i][rank] = 0.0
        op1_comm_time[i][rank] = 0.0
        op2_comm_time[i][rank] = 0.0

        timing_on = timing == True
        timing = False
        outputs = train(inputs_loc, weight1, weight2, adj_matrix_loc, am_pbyp,
                        optimizer, data, rank, size, group)
        if timing_on:
            timing = True

        dist.barrier(group)
        tstart = time.time()

        # for epoch in range(1, 201):
        print(f"Starting training... rank {rank} run {i}", flush=True)
        for epoch in range(1, epochs):
            outputs = train(inputs_loc, weight1, weight2, adj_matrix_loc,
                            am_pbyp, optimizer, data, rank, size, group)
            print("Epoch: {:03d}".format(epoch), flush=True)

        # dist.barrier(group)
        tstop = time.time()
        total_time[i][rank] = tstop - tstart

    # Get median runtime according to rank0 and print that run's breakdown
    dist.barrier(group)
    if rank == 0:
        total_times_r0 = []
        for i in range(run_count):
            total_times_r0.append(total_time[i][0])

        print(f"total_times_r0: {total_times_r0}")
        median_run_time = statistics.median(total_times_r0)
        median_idx = total_times_r0.index(median_run_time)
        median_idx = torch.cuda.LongTensor([median_idx])
    else:
        median_idx = torch.cuda.LongTensor([0])

    dist.broadcast(median_idx, src=0, group=group)
    median_idx = median_idx.item()
    print(f"rank: {rank} median_run: {median_idx}")
    print(f"rank: {rank} total_time: {total_time[median_idx][rank]}")
    print(f"rank: {rank} comm_time: {comm_time[median_idx][rank]}")
    print(f"rank: {rank} comp_time: {comp_time[median_idx][rank]}")
    print(f"rank: {rank} scomp_time: {scomp_time[median_idx][rank]}")
    print(f"rank: {rank} dcomp_time: {dcomp_time[median_idx][rank]}")
    print(f"rank: {rank} bcast_comm_time: {bcast_comm_time[median_idx][rank]}")
    print(f"rank: {rank} barrier_time: {barrier_time[median_idx][rank]}")
    print(
        f"rank: {rank} barrier_subset_time: {barrier_subset_time[median_idx][rank]}"
    )
    print(f"rank: {rank} op1_comm_time: {op1_comm_time[median_idx][rank]}")
    print(f"rank: {rank} op2_comm_time: {op2_comm_time[median_idx][rank]}")
    print(f"rank: {rank} {outputs}")

    if accuracy:
        # All-gather outputs to test accuracy
        output_parts = []
        n_per_proc = math.ceil(float(inputs.size(0)) / size)
        # print(f"rows: {am_pbyp[-1].size(0)} cols: {classes}", flush=True)
        for i in range(size):
            output_parts.append(
                torch.cuda.FloatTensor(n_per_proc, classes,
                                       device=device).fill_(0))

        if outputs.size(0) != n_per_proc:
            pad_row = n_per_proc - outputs.size(0)
            outputs = torch.cat(
                (outputs,
                 torch.cuda.FloatTensor(pad_row, classes, device=device)),
                dim=0)

        dist.all_gather(output_parts, outputs)
        output_parts[rank] = outputs

        padding = inputs.size(0) - n_per_proc * (size - 1)
        output_parts[size - 1] = output_parts[size - 1][:padding, :]

        outputs = torch.cat(output_parts, dim=0)

        train_acc, val_acc, tmp_test_acc = test(outputs, data,
                                                am_pbyp[0].size(1), rank)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'

        print(log.format(900, train_acc, best_val_acc, test_acc))

    return outputs
Example #8
    def _run(self, tempdir):
        my_rank = dist.get_rank()
        fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302]

        metrics_saver = MetricsSaver(
            save_dir=tempdir,
            metrics=["metric1", "metric2"],
            metric_details=["metric3", "metric4"],
            batch_transform=lambda x: x[PostFix.meta("image")],
            summary_ops="*",
            delimiter="\t",
        )

        def _val_func(engine, batch):
            pass

        engine = Engine(_val_func)

        if my_rank == 0:
            data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics0(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[1, 2]]),
                    "metric4": torch.tensor([[5, 6]])
                }

        if my_rank == 1:
            # different ranks have different data length
            data = [
                {
                    PostFix.meta("image"): {
                        "filename_or_obj": [fnames[1]]
                    }
                },
                {
                    PostFix.meta("image"): {
                        "filename_or_obj": [fnames[2]]
                    }
                },
            ]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics1(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[2, 3], [3, 4]]),
                    "metric4": torch.tensor([[6, 7], [7, 8]]),
                }

        @engine.on(Events.EPOCH_COMPLETED)
        def _all_gather(engine):
            scores = engine.state.metric_details["metric3"]
            engine.state.metric_details[
                "metric3"] = evenly_divisible_all_gather(data=scores,
                                                         concat=True)
            scores = engine.state.metric_details["metric4"]
            engine.state.metric_details[
                "metric4"] = evenly_divisible_all_gather(data=scores,
                                                         concat=True)

        metrics_saver.attach(engine)
        engine.run(data, max_epochs=1)

        if my_rank == 0:
            # check the metrics.csv and content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metrics.csv")))
            with open(os.path.join(tempdir, "metrics.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
            # check the metric_raw.csv and content
            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i > 0:
                        expected = [
                            f"{fnames[i-1]}\t{float(i):.4f}\t{float(i + 1):.4f}\t{i + 0.5:.4f}"
                        ]
                        self.assertEqual(row, expected)
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
            # check the metric_summary.csv and content
            with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i == 1:
                        self.assertEqual(row, [
                            "class0\t2.0000\t2.0000\t3.0000\t1.0000\t2.8000\t0.8165\t3.0000"
                        ])
                    elif i == 2:
                        self.assertEqual(row, [
                            "class1\t3.0000\t3.0000\t4.0000\t2.0000\t3.8000\t0.8165\t3.0000"
                        ])
                    elif i == 3:
                        self.assertEqual(row, [
                            "mean\t2.5000\t2.5000\t3.5000\t1.5000\t3.3000\t0.8165\t3.0000"
                        ])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
        dist.barrier()
Example #9
def _train_worker(
    process_rank: int,
    params: Params,
    serialization_dir: str,
    file_friendly_logging: bool = False,
    recover: bool = False,
    cache_directory: str = None,
    cache_prefix: str = None,
    include_package: List[str] = None,
    node_rank: int = 0,
    master_addr: str = "127.0.0.1",
    master_port: int = 29500,
    world_size: int = 1,
    distributed_device_ids: List[str] = None,
) -> Optional[Model]:
    """
    Helper to train the configured model/experiment. In distributed mode, this is spawned as a
    worker process. In a single GPU experiment, this returns the ``Model`` object and in distributed
    training, nothing is returned.

    Parameters
    ----------
    process_rank : ``int``
        The process index that is initialized using the GPU device id.
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory.  This is only intended for use when something actually crashed during the middle
        of a run.  For continuing training a model on new data, see the ``fine-tune`` command.
    cache_directory : ``str``, optional
        For caching data pre-processing.  See :func:`allennlp.training.util.datasets_from_params`.
    cache_prefix : ``str``, optional
        For caching data pre-processing.  See :func:`allennlp.training.util.datasets_from_params`.
    include_package: ``List[str]``, optional
        In distributed mode, since this function would have been spawned as a separate process,
        the extra imports need to be done again. NOTE: This does not have any effect in single
        GPU training.
    node_rank: ``int``, optional
        Rank of the node in multi-node training.
    master_addr: ``str``, optional (default="127.0.0.1")
        Address of the master node, used to initialize the process group.
    master_port: ``int``, optional (default=29500)
        Port on the master node, used to initialize the process group.
    world_size: ``int``, optional
        The number of processes involved in distributed training.
    distributed_device_ids: ``List[str]``, optional
        Device ids of the GPUs available to this node for distributed training.

    Returns
    -------
    best_model: ``Model``
        The model with the best epoch weights.
    """
    prepare_global_logging(serialization_dir,
                           file_friendly_logging,
                           rank=process_rank,
                           world_size=world_size)
    prepare_environment(params)

    distributed = world_size > 1

    # not using `allennlp.common.util.is_master` as the process group is yet to be initialized
    master = process_rank == 0

    evaluate_on_test = params.pop_bool("evaluate_on_test", False)

    if distributed:
        # Since the worker is spawned and not forked, the extra imports
        # need to be done again.
        if include_package is not None:
            for package_name in include_package:
                import_submodules(package_name)

        num_procs_per_node = len(distributed_device_ids)
        # The Unique identifier of the worker process among all the processes in the
        # distributed training group is computed here. This is used while initializing
        # the process group using `init_process_group`
        global_rank = node_rank * num_procs_per_node + process_rank

        # In distributed training, the configured device is always going to be a list.
        # The corresponding gpu id for the particular worker is obtained by picking the id
        # from the device list with the rank as index
        gpu_id = distributed_device_ids[process_rank]  # type: ignore

        # Till now, "cuda_device" might not be set in the trainer params.
        # But a worker trainer needs to only know about its specific GPU id.
        params["trainer"]["cuda_device"] = gpu_id
        params["trainer"]["world_size"] = world_size
        params["trainer"]["distributed"] = True

        torch.cuda.set_device(gpu_id)
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}:{master_port}",
            world_size=world_size,
            rank=global_rank,
        )
        logging.info(f"Process group of world size {world_size} initialized "
                     f"for distributed training in worker {global_rank}")

    trainer_type = params.get("trainer", {}).get("type", "default")

    if trainer_type == "default":
        # Special logic to instantiate backward-compatible trainer.
        pieces = TrainerPieces.from_params(params, serialization_dir, recover,
                                           cache_directory, cache_prefix)
        trainer = Trainer.from_params(
            model=pieces.model,
            serialization_dir=serialization_dir,
            iterator=pieces.iterator,
            train_data=pieces.train_dataset,
            validation_data=pieces.validation_dataset,
            params=pieces.params,
            validation_iterator=pieces.validation_iterator,
        )

        evaluation_iterator = pieces.validation_iterator or pieces.iterator
        evaluation_dataset = pieces.test_dataset

    else:
        if evaluate_on_test:
            raise ValueError(
                "--evaluate-on-test only works with the default Trainer. "
                "If you're using the CallbackTrainer you can use a callback "
                "to evaluate at Events.TRAINING_END; otherwise you'll have "
                "to run allennlp evaluate separately.")

        trainer = TrainerBase.from_params(params, serialization_dir, recover,
                                          cache_directory, cache_prefix)
        evaluation_dataset = None

    params.assert_empty("base train command")

    try:
        if distributed:  # let the setup get ready for all the workers
            dist.barrier()

        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if master and os.path.exists(
                os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logging.info(
                "Training interrupted by the user. Attempting to create "
                "a model archive using the current best epoch weights.")
            archive_model(serialization_dir,
                          files_to_archive=params.files_to_archive)
        raise

    if master:
        if evaluation_dataset and evaluate_on_test:
            logger.info(
                "The model will be evaluated using the best epoch weights.")
            test_metrics = evaluate(
                trainer.model,
                evaluation_dataset,
                evaluation_iterator,
                cuda_device=trainer.cuda_device,
                # TODO(brendanr): Pass in an arg following Joel's trainer refactor.
                batch_weight_key="",
            )

            for key, value in test_metrics.items():
                metrics["test_" + key] = value
        elif evaluation_dataset:
            logger.info(
                "To evaluate on the test set after training, pass the "
                "'evaluate_on_test' flag, or use the 'allennlp evaluate' command."
            )
        dump_metrics(os.path.join(serialization_dir, "metrics.json"),
                     metrics,
                     log=True)

    if not distributed:
        return trainer.model

    return None  # to make mypy happy
Example #10
    def test_barrier(self):
        # nothing to verify. Just run it through.
        dist.barrier()
Example #11
def main(rank: int, num_random_shapes: int):
    if rank == 0:
        print(torch.__version__)
        print(f'cudnn_ver: {torch.backends.cudnn.version()}')

    report_every = min(1000, max(1, num_random_shapes // 20))
    t_start = time.time()

    device = f'cuda:{rank}'
    torch.cuda.set_device(device)
    print(f'#{rank}', torch.cuda.get_device_name(rank))

    calculated_shapes = 0
    mismatches = [0, 0, 0]  # forward, dgrad, wgrad
    exception_shapes = 0
    con_exceptions = Counter()

    d_size = {torch.float: 4, torch.half: 2}

    while calculated_shapes < num_random_shapes:
        n = random.randint(1, 8)
        c = random.randint(8 // 8, 512 // 8) * 8
        d = random.randint(8 // 4, 512 // 4) * 4
        # h = random.randint(4, 512)
        # w = random.randint(4, 512)
        h = d
        w = d

        ks = random.choice([1, 2, 3, 5])
        dtype = random.choice([torch.half, torch.float])

        tensor_size = (
            2 * 2 * n * c * d * h * w +  # (input + input.grad) and ref
            2 * 2 * n * c +  # (weight + weight.grad) and ref
            2 * 2 * n * c * d * h * w  # (output (est) + output.grad) and ref
        ) * d_size[dtype] / 1e9
        if tensor_size > 4.0:
            # print('tensor too large, continue')
            continue

        oom = False

        try:
            # convert to channels_last_3d first, then mark the converted tensor as a leaf that requires grad
            x = torch.randn(n, c, d, h, w, dtype=dtype, device=device)
            x = x.to(memory_format=torch.channels_last_3d).requires_grad_()
            ref_cont_x = x.detach().clone().contiguous().requires_grad_()

            net = torch.nn.BatchNorm3d(c)
            net = net.to(dtype=dtype,
                         device=device,
                         memory_format=torch.channels_last_3d)
            ref_cont_net = torch.nn.BatchNorm3d(c)
            ref_cont_net = ref_cont_net.to(
                dtype=dtype,
                device=device,
                memory_format=torch.channels_last_3d)

            with torch.no_grad():
                for p, rp in zip(net.parameters(), ref_cont_net.parameters()):
                    rp.copy_(p)

            out = net(x)
            ref_cont_out = ref_cont_net(ref_cont_x)

            out.sum().backward()
            ref_cont_out.sum().backward()

            _a, _b = _compare_tensors_internal(out,
                                               ref_cont_out,
                                               atol=1e-3,
                                               rtol=1e-3,
                                               equal_nan=False)
            if not _a:
                mismatches[0] += 1
            _c, _d = _compare_tensors_internal(x.grad,
                                               ref_cont_x.grad,
                                               atol=1e-3,
                                               rtol=1e-3,
                                               equal_nan=False)
            if not _c:
                mismatches[1] += 1
            _e, _f = _compare_tensors_internal(net.weight.grad,
                                               ref_cont_net.weight.grad,
                                               atol=1e-3,
                                               rtol=1e-3,
                                               equal_nan=False)
            if not _e:
                mismatches[2] += 1
        except RuntimeError as e:
            exc_type, exc_value, exc_traceback = sys.exc_info()
            if str(e).startswith('CUDA out of memory'):
                oom = True
            else:
                print(f'*************** {rank=} {n=} {c=} {d=} {h=} {w=} \n'
                      f'*************** {dtype=} {tensor_size=:.3f} GB\n'
                      f'*************** {exc_type}: {exc_value}')
                con_exceptions[str(exc_value)] += 1

                # traceback.print_exc()
                exception_shapes += 1
                # raise

        if oom:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            continue

        calculated_shapes += 1

        if calculated_shapes // report_every != (calculated_shapes -
                                                 1) // report_every:
            t_now = time.time()
            tl_est = (num_random_shapes - calculated_shapes) * (
                t_now - t_start) / calculated_shapes
            print(
                f'#{rank} time cost = {t_now - t_start :.3f}, {calculated_shapes = }, time left (est) = {tl_est :.3f}'
            )

    print(f'#{rank} Exceptions:', json.dumps(con_exceptions, indent=2))

    dist.barrier()

    l_report = [*mismatches, exception_shapes]
    t_report = torch.Tensor(l_report).cuda()
    dist.reduce(t_report, dst=0)
    if rank == 0:
        print()
        print(f'total shapes = {num_random_shapes * dist.get_world_size()}')
        print('mismatches: forward, dgrad, wgrad; num of exception shapes')
        print(t_report.cpu().numpy())
Example #12
    def _worker(gpu_id: int, sync_file: str, world_size: int):
        torch.manual_seed(0)
        os.environ["RANK"] = str(gpu_id)
        init_distributed_on_file(
            world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
        )
        torch.backends.cudnn.deterministic = True

        config = TestCheckpointConversion._create_fsdp_model_config(with_fsdp=True)
        model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        optimizer = optim.SGD(model.parameters(), lr=1e-4)

        # Fake inputs
        num_iterations = 5
        batch_size = 3
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        # Fake training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
            output1, output2 = model(fake_input)[0]
            loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
                output2.sum(axis=-1), fake_target
            )
            if gpu_id == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save a bunch of checkpoint, one by shard
        checkpoint_writer = CheckpointWriter(
            checkpoint_folder=".",
            is_final_train_phase=True,
            mode="iteration",
            mode_num=0,
            backend="disk",
        )
        content = {
            "classy_state_dict": {
                "base_model": {
                    "model": {"trunk": model.trunk.local_state_dict()},
                    "meta": {"trunk": model.trunk.local_metadata_dict()},
                }
            }
        }
        checkpoint_writer.save_sharded_checkpoint(
            content, shard_rank=gpu_id, world_size=world_size
        )
        dist.barrier()
        print(os.listdir("."))

        # Convert the checkpoint to consolidated and sliced checkpoints
        if gpu_id == 0:
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch"
            )
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch"
            )
        dist.barrier()
        print(os.listdir("."))

        # Now create models initialized from the previous checkpoint and compare them
        fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)

        shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint.torch", device=torch.device("cpu")
        )
        shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
        shard_model.init_model_from_weights_params_file(config, shard_cp)

        conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_conso.torch", device=torch.device("cpu")
        )
        conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
        conso_model.init_model_from_weights_params_file(config, conso_cp)

        slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_sliced.torch", device=torch.device("cpu")
        )
        slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
        slice_model.init_model_from_weights_params_file(config, slice_cp)

        # Verifying that the models are equivalent
        if gpu_id == 0:
            slice_state_dict = slice_model.local_state_dict()
            conso_state_dict = conso_model.local_state_dict()
            assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
            for k in slice_state_dict.keys():
                slice_val = slice_state_dict[k]
                conso_val = conso_state_dict[k]
                assert torch.allclose(
                    slice_val, conso_val
                ), f"Difference for key {k}: {slice_val} VS {conso_val}"
        dist.barrier()

        with torch.no_grad():
            ref_out = model.trunk(fake_test_input)[0]
            shard_out = shard_model.trunk(fake_test_input)[0]
            conso_out = conso_model.trunk(fake_test_input)[0]
            slice_out = slice_model.trunk(fake_test_input)[0]
            assert torch.allclose(
                ref_out, shard_out
            ), f"{ref_out.sum()} vs {shard_out.sum()}"
            assert torch.allclose(
                ref_out, conso_out
            ), f"{ref_out.sum()} vs {conso_out.sum()}"
            assert torch.allclose(
                ref_out, slice_out
            ), f"{ref_out.sum()} vs {slice_out.sum()}"
Example #13
def train(model, local_rank):
    log_path = 'train_log'
    if local_rank == 0:
        writer = SummaryWriter(log_path + '/train')
        writer_val = SummaryWriter(log_path + '/validate')
    else:
        writer, writer_val = None, None
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = train_data.__len__()
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
    evaluate(model, val_data, nr_eval, local_rank, writer_val)
    model.save_model(log_path, local_rank)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, flow_gt = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            flow_gt = flow_gt.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            mul = np.cos(step / (args.epoch * args.step_per_epoch) * math.pi) * 0.5 + 0.5
            learning_rate = get_learning_rate(step)
            pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, learning_rate, mul, True, flow_gt)
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if step % 100 == 1 and local_rank == 0:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss_l1', loss_l1, step)
                writer.add_scalar('loss_flow', loss_flow, step)
                writer.add_scalar('loss_cons', loss_cons, step)
                writer.add_scalar('loss_ter', loss_ter, step)
            if step % 1000 == 1 and local_rank == 0:
                gt = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                pred = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                merged_img = (merged_img.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                flow = flow.permute(0, 2, 3, 1).detach().cpu().numpy()
                flow_mask = flow_mask.permute(0, 2, 3, 1).detach().cpu().numpy()
                flow_gt = flow_gt.permute(0, 2, 3, 1).detach().cpu().numpy()
                # use a separate index so the dataloader step index `i` logged below is not clobbered
                for j in range(5):
                    imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
                    writer.add_image(str(j) + '/img', imgs, step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow', flow2rgb(flow[j]), step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow_gt', flow2rgb(flow_gt[j]), step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow_mask', flow2rgb(flow[j] * flow_mask[j]), step, dataformats='HWC')
                writer.flush()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss_l1))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank, writer_val)
        model.save_model(log_path, local_rank)    
        dist.barrier()
Example #14
    def train(self):
        if self.verbose:
            loss_meter = LossMeter()
            best_valid = 0.

            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.args.log_dir)

            hparam_dict = {}
            for k, v in self.args.__dict__.items():
                if type(v) in [int, float, str, bool, torch.Tensor]:
                    hparam_dict[k] = v
            metric_dict = {}

            self.writer.add_hparams(hparam_dict, metric_dict)

        if self.args.distributed:
            dist.barrier()

        self.optim.zero_grad()

        for epoch in range(self.args.epochs):
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=150)

                results = np.zeros(4, dtype=int)
                quesid2ans = {}

            for step_i, batch in enumerate(self.train_loader):
                vis_feats = batch['vis_feats'].cuda()
                boxes = batch['boxes'].cuda()

                ques_id = batch['question_ids']
                B = len(ques_id)

                input_ids = batch['word_ids'].cuda()
                input_ids = input_ids.unsqueeze(1).repeat(1, 2,
                                                          1).view(B * 2, -1)
                label = batch['labels'].cuda()

                # use a separate name so the `results` counter above is not overwritten
                out_dict = self.model(
                    input_ids=input_ids,
                    visual_feats=vis_feats,
                    visual_pos=boxes,
                    attention_mask=input_ids > 0,
                )

                logit = out_dict['logit']

                loss = self.mce_loss(logit, label)

                loss.backward()

                update = True
                if self.args.update_freq > 1:
                    if step_i == 0:
                        update = False
                    elif step_i % self.args.update_freq == 0 or step_i == len(
                            self.train_loader) - 1:
                        update = True
                    else:
                        update = False

                if update:
                    if not self.args.no_clip_grad:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.args.clip_grad_norm)

                    self.optim.step()
                    self.lr_scheduler.step()
                    for param in self.model.parameters():
                        param.grad = None

                try:
                    lr = self.lr_scheduler.get_last_lr()[0]
                except AttributeError:
                    lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:.4f} |'

                    score, predict = logit.max(1)
                    predict = predict.cpu().numpy()
                    label = label.cpu().numpy()

                    for qid, pred in zip(ques_id, predict):
                        quesid2ans[qid] = pred

                    results[0] += sum((label == 1) & (predict == 1))
                    results[1] += sum((label == 1) & (predict == 0))
                    results[2] += sum((label == 0) & (predict == 1))
                    results[3] += sum((label == 0) & (predict == 0))
                    n_total = sum(results)

                    desc_str += f' TP {results[0]} ({results[0]/n_total*100:.1f}%)'
                    desc_str += f' FN {results[1]} ({results[1]/n_total*100:.1f}%)'
                    desc_str += f' FP {results[2]} ({results[2]/n_total*100:.1f}%)'
                    desc_str += f' TN {results[3]} ({results[3]/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
                score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
                log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)

                if not self.args.dry:
                    self.writer.add_scalar(f'NLVR/Train/score', score, epoch)

                # Validation
                valid_score = self.evaluate(self.val_loader) * 100.
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)

                if not self.args.dry:
                    self.writer.add_scalar(f'NLVR/Valid/score', valid_score,
                                           epoch)

                print(log_str)
                self.logger.info(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")
Example #15
    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()
        self.train_loss_metrics.reset()
        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader):
            try:
                step_idx += 1
                for key, input_value in input_data_item.items():
                    if input_value is not None and isinstance(
                            input_value, torch.Tensor):
                        input_data_item[key] = input_value.to(
                            self.device, non_blocking=True)
                if self.config['trainer']['anomaly_detection']:
                    # This mode will increase the runtime and should only be enabled for debugging
                    with torch.autograd.detect_anomaly():
                        self.optimizer.zero_grad()
                        # model forward
                        output = self.model(**input_data_item)
                        # calculate loss
                        gl_loss = output['gl_loss']
                        crf_loss = output['crf_loss']
                        total_loss = torch.sum(
                            crf_loss
                        ) + self.gl_loss_lambda * torch.sum(gl_loss)
                        # backward
                        total_loss.backward()
                        # self.average_gradients(self.model)
                        self.optimizer.step()
                else:
                    self.optimizer.zero_grad()
                    # model forward
                    output = self.model(**input_data_item)
                    # calculate loss
                    gl_loss = output['gl_loss']
                    crf_loss = output['crf_loss']
                    total_loss = torch.sum(
                        crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                    # backward
                    total_loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()

                # Use a barrier() to make sure that all processes have finished forward and backward
                if self.distributed:
                    dist.barrier()
                    # sum the losses over all processes; divide by the world size below
                    dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
                    dist.all_reduce(gl_loss, op=dist.ReduceOp.SUM)
                    dist.all_reduce(crf_loss, op=dist.ReduceOp.SUM)
                    size = dist.get_world_size()
                else:
                    size = 1
                gl_loss /= size  # averages gl_loss across the whole world
                crf_loss /= size  # averages crf_loss across the whole world

                # calculate average loss across the batch size
                avg_gl_loss = torch.mean(gl_loss)
                avg_crf_loss = torch.mean(crf_loss)
                avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss
                # update metrics
                if self.local_master:
                    self.writer.set_step((epoch - 1) * self.len_step + step_idx - 1)
                self.train_loss_metrics.update('loss', avg_loss.item())
                self.train_loss_metrics.update(
                    'gl_loss',
                    avg_gl_loss.item() * self.gl_loss_lambda)
                self.train_loss_metrics.update('crf_loss', avg_crf_loss.item())

                # log messages
                if step_idx % self.log_step == 0:
                    self.logger_info(
                        'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                        .format(
                            epoch, self.epochs, step_idx, self.len_step,
                            self.train_loss_metrics.avg('loss'),
                            self.train_loss_metrics.avg('gl_loss') *
                            self.gl_loss_lambda,
                            self.train_loss_metrics.avg('crf_loss')))
                    # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                # do validation after val_step_interval iteration
                if self.do_validation and step_idx % self.val_step_interval == 0:
                    val_result_dict = self._valid_epoch(epoch)
                    self.logger_info(
                        '[Step Validation] Epoch:[{}/{}] Step:[{}/{}]  \n{}'.
                        format(
                            epoch, self.epochs, step_idx, self.len_step,
                            SpanBasedF1MetricTracker.dict2str(
                                val_result_dict)))

                    # check if best metric, if true, then save as model_best checkpoint.
                    best, not_improved_count = self._is_best_monitor_metric(
                        False, 0, val_result_dict)
                    if best:
                        self._save_checkpoint(epoch, best)

                # decide whether continue iter
                if step_idx == self.len_step + 1:
                    break
            except Exception as e:
                # most likely a CUDA out-of-memory error; report it and skip this step
                print('Skipping step {} due to error: {}'.format(step_idx, e))

        ## step iteration end ##

        # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
        log = self.train_loss_metrics.result()
        self.writer.set_step(epoch, 'train')
        self.writer.add_scalar('total_loss_epoch', log['loss'])
        self.writer.add_scalar('gl_loss_epoch', log['gl_loss'])
        self.writer.add_scalar('crf_loss_epoch', log['crf_loss'])
        # do validation after training an epoch
        if self.do_validation:
            val_result_dict = self._valid_epoch(epoch)
            log['val_result_dict'] = val_result_dict
            self.writer.set_step(epoch, 'valid')
            self.writer.add_scalars(
                'mEF_valid', {
                    'total': val_result_dict['total']['mEF'],
                    'date': val_result_dict['date']['mEF'],
                    'company': val_result_dict['company']['mEF'],
                    'address': val_result_dict['address']['mEF']
                })

            # self.writer.add_scalars('mEF_valid',{'total':val_result_dict['total']['mEF'],
            #                                      'product':val_result_dict['product']['mEF'],
            #                                      'price':val_result_dict['price']['mEF']})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log
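
The distributed branch above reduces the per-process losses with an all-reduce SUM and then divides by the world size to obtain a global average. A minimal helper capturing that pattern (a sketch, not part of the trainer above):

import torch
import torch.distributed as dist

def global_mean(value: torch.Tensor) -> torch.Tensor:
    """Average a per-process loss tensor across all ranks (no-op without a process group)."""
    if dist.is_available() and dist.is_initialized():
        value = value.detach().clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)  # sum over all ranks
        value /= dist.get_world_size()                # sum -> mean
    return value
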
    def block(self):
        if not self.distributed:
            return
        self.logger.info('blocking')
        dist.barrier()

    def barrier(self, name: str = None):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()
Example #18
0
    def get_data_loader(
        self,
        config: Coqpit,
        ap: AudioProcessor,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # setup multi-speaker attributes
            if hasattr(self, "speaker_manager"):
                speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
                d_vector_mapping = (
                    self.speaker_manager.d_vectors
                    if config.use_speaker_embedding and config.use_d_vector_file
                    else None
                )
            else:
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup custom symbols if needed
            custom_symbols = None
            if hasattr(self, "make_symbols"):
                custom_symbols = self.make_symbols(self.config)

            # init dataset
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                text_cleaner=config.text_cleaner,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                meta_data=data_items,
                ap=ap,
                characters=config.characters,
                custom_symbols=custom_symbols,
                add_blank=config["add_blank"],
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
                phoneme_cache_path=config.phoneme_cache_path,
                use_phonemes=config.use_phonemes,
                phoneme_language=config.phoneme_language,
                enable_eos_bos=config.enable_eos_bos_chars,
                use_noise_augment=not is_eval,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping
                if config.use_speaker_embedding and config.use_d_vector_file
                else None,
            )

            # pre-compute phonemes
            if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
                if hasattr(self, "eval_data_items") and is_eval:
                    dataset.items = self.eval_data_items
                elif hasattr(self, "train_data_items") and not is_eval:
                    dataset.items = self.train_data_items
                else:
                    # precompute phonemes for precise estimate of sequence lengths.
                    # otherwise `dataset.sort_items()` uses raw text lengths
                    dataset.compute_input_seq(config.num_loader_workers)

                    # TODO: find a more efficient solution
                    # cheap hack - store items in the model state to avoid recomputing when reinit the dataset
                    if is_eval:
                        self.eval_data_items = dataset.items
                    else:
                        self.train_data_items = dataset.items

            # halt DDP processes for the main process to finish computing the phoneme cache
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False))

            # compute pitch frames and write to files.
            if config.compute_f0 and rank in [None, 0]:
                if not os.path.exists(config.f0_cache_path):
                    dataset.pitch_extractor.compute_pitch(
                        ap, config.get("f0_cache_path", None), config.num_loader_workers
                    )

            # halt DDP processes for the main process to finish computing the F0 cache
            if num_gpus > 1:
                dist.barrier()

            # load pitch stats computed above by all the workers
            if config.compute_f0:
                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))

            # sampler for DDP
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None

            # init dataloader
            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader
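
The loader setup above repeats one idea twice: only rank 0 computes an expensive cache (first phonemes, then F0), and the remaining DDP workers wait at a barrier until the files exist on disk. A hedged sketch of that pattern, where compute_cache and load_cache are hypothetical callables:

import torch.distributed as dist

def build_shared_cache(rank, num_gpus, compute_cache, load_cache):
    # only the main process pays the cost of building the cache
    if rank in (None, 0):
        compute_cache()
    # the other DDP processes wait until the cache is written to disk
    if num_gpus > 1:
        dist.barrier()
    # afterwards every rank loads the same result
    return load_cache()
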
    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
        dist.all_reduce(stop, op=dist.ReduceOp.SUM)
        dist.barrier()
        should_stop = stop == self.trainer.world_size
        return should_stop
Example #20
0
def barrier():
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
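
Because the helper returns silently when no process group exists, the same script runs unchanged on a single GPU and under a distributed launcher. A hedged usage sketch that builds on the barrier() above; the callables passed in are placeholders, not code from this document:

import torch.distributed as dist

def is_main_process():
    # rank 0 in distributed runs, and trivially the only process otherwise
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0

def run(num_epochs, model, loader, train_one_epoch, save_checkpoint):
    for epoch in range(num_epochs):
        train_one_epoch(model, loader)   # caller-supplied training step
        barrier()                        # guarded helper above; no-op on one GPU
        if is_main_process():
            save_checkpoint(model, epoch)
        barrier()                        # keep workers aligned with rank 0
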
Example #21
0
def task_train(model,
               buffer,
               lifelong_datasets,
               config,
               metadata,
               logbook,
               dist_args=None):
    distributed = dist_args is not None
    if distributed:
        gpu = dist_args["gpu"]
        rank = dist_args["rank"]
    else:
        gpu = None
        rank = 0

    best_checkpoint = {
        "model_state_dict": deepcopy(model.method_state_dict()),
        "best_modified_jaccard": 0
    }
    best_checkpoint_file = os.path.join(config['logging_path'],
                                        "best_checkpoint")
    if config['use_best_model']:
        if config['task_epoch'] > 0 and os.path.exists(best_checkpoint_file):
            if distributed:
                best_checkpoint = torch.load(
                    best_checkpoint_file,
                    map_location=f"cuda:{dist_args['gpu']}")
            else:
                best_checkpoint = torch.load(best_checkpoint_file)

    task_train_data = lifelong_datasets['train']
    if config["method"] == "agem":
        bsm = 0.0
        task_train_data_with_buffer = TaskDataMergedWithBuffer(
            buffer, task_train_data, buffer_sampling_multiplier=bsm)
    else:
        bsm = config["buffer_sampling_multiplier"]
        task_train_data_with_buffer = TaskDataMergedWithBuffer(
            buffer, task_train_data, buffer_sampling_multiplier=bsm)
    task_valid_data = lifelong_datasets['intask_valid']
    cur_task_id = task_train_data.cur_task_id

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            task_train_data_with_buffer,
            num_replicas=dist_args["world_size"],
            rank=rank)
    else:
        train_sampler = None

    train_loader = data.DataLoader(task_train_data_with_buffer,
                                   batch_size=config["batch_size"],
                                   shuffle=(train_sampler is None),
                                   num_workers=config["num_workers"],
                                   pin_memory=True,
                                   sampler=train_sampler)
    valid_loader = data.DataLoader(task_valid_data,
                                   batch_size=config["batch_size"],
                                   shuffle=False,
                                   num_workers=config["num_workers"],
                                   pin_memory=True)

    if cur_task_id == 0:
        num_epochs = config['epochs_per_task'] * 2
        print_msg(
            f"Training for {num_epochs} epochs for the first task (double that of the other tasks)"
        )
    else:
        num_epochs = config['epochs_per_task']
    print_msg(
        f"Starting training of task {cur_task_id} epoch {config['task_epoch']} till epoch {num_epochs}"
    )
    for epoch in range(config['task_epoch'], num_epochs):
        if distributed:
            train_sampler.set_epoch(epoch)
        start_time = time.time()
        log_dict = {}
        train_loss, train_metrics = epoch_train(model, train_loader, config,
                                                metadata, gpu, rank)
        log_dict[f"train_loss_{cur_task_id}"] = train_loss
        for metric in train_metrics.keys():
            log_dict[f"train_{metric}_{cur_task_id}"] = train_metrics[metric]

        valid_loss, valid_metrics = evaluate(model,
                                             valid_loader,
                                             config,
                                             metadata,
                                             test_mode=False,
                                             gpu=gpu)
        log_dict[f"valid_loss_{cur_task_id}"] = valid_loss
        for metric in valid_metrics.keys():
            log_dict[f"valid_{metric}_{cur_task_id}"] = valid_metrics[metric]

        model.net.eval()
        model.consolidate_epoch_knowledge(
            log_dict[f"valid_modified_jaccard_{cur_task_id}"],
            task_data=task_train_data,
            device=config["device"],
            batch_size=config["batch_size"])
        # If using the lmdb database, close it and open a new environment to kill active readers
        buffer.reset_lmdb_database()

        if config['use_best_model']:
            if log_dict[
                    f"valid_modified_jaccard_{cur_task_id}"] >= best_checkpoint[
                        "best_modified_jaccard"]:
                best_checkpoint["best_modified_jaccard"] = log_dict[
                    f"valid_modified_jaccard_{cur_task_id}"]
                best_checkpoint["model_state_dict"] = deepcopy(
                    model.method_state_dict())
            log_dict[
                f"best_valid_modified_jaccard_{cur_task_id}"] = best_checkpoint[
                    "best_modified_jaccard"]

        if distributed:
            dist.barrier()  #to calculate the time based on the slowest gpu
        end_time = time.time()
        log_dict[f"elapsed_time"] = round(end_time - start_time, 2)

        if rank == 0:
            utils.log(epoch, cur_task_id, log_dict, logbook)
            if distributed:
                dist.barrier()
                log_dict["rank"] = rank
                print_msg(log_dict)
        else:
            dist.barrier()
            log_dict["rank"] = rank
            print_msg(log_dict)

        # Checkpointing
        config["task_epoch"] = epoch + 1
        if (config["task_epoch"] %
                config['checkpoint_interval']) == 0 and rank == 0:
            print_msg("Saving latest checkpoint")
            save_file = os.path.join(config['logging_path'], "latest_model")
            lifelong_methods.utils.save_model(save_file, config, metadata,
                                              model, buffer, lifelong_datasets)
            if config['use_best_model']:
                print_msg("Saving best checkpoint")
                torch.save(best_checkpoint, best_checkpoint_file)

    # reset the model parameters to the best performing model
    if config['use_best_model']:
        model.load_method_state_dict(best_checkpoint["model_state_dict"])
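
The barrier placed before end_time above makes the reported epoch time reflect the slowest GPU rather than whichever rank finishes first. A minimal sketch of that timing pattern (train_fn is a placeholder for the epoch body):

import time
import torch.distributed as dist

def timed_epoch(train_fn, distributed: bool) -> float:
    start_time = time.time()
    train_fn()
    if distributed:
        dist.barrier()  # wait for the slowest rank so the measurement is comparable
    return round(time.time() - start_time, 2)
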
def main():
    args = parse_args()
    
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend="nccl", init_method="env://"
    )
    dist.barrier()
    to_output = dist.get_rank() == 0
    
    if to_output:
        logger, final_output_dir, tb_log_dir = create_logger(
            config, args.cfg, 'train')

        logger.info(pprint.pformat(args))
        logger.info("This is config file")
        logger.info(pprint.pformat(config))
    else:
        final_output_dir, tb_log_dir = None, None

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.'+config.MODEL.NAME+'.get_cls_net')(
        config)
    if args.load_model:
        model = models.load_model(model, args.load_model)
    model = model.cuda()

    if to_output:
        dump_input = torch.rand(
            (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0])
        ).cuda()
        logger.info(get_model_summary(model, dump_input))

        # copy model file
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

        #writer_dict = {
        #    'writer': SummaryWriter(log_dir=tb_log_dir),
        #    'train_global_steps': 0,
        #    'valid_global_steps': 0,
        #}

    model = torch.nn.parallel.DistributedDataParallel(
      model, device_ids=[args.local_rank], output_device=args.local_rank
    )

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    
    '''
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))
            best_model = True
    '''
            
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch-1
        )

    # Data loading code
    traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)
    print(traindir, valdir)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    train_sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU, 
        shuffle=False,
        sampler=train_sampler,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True
    )

    val_dataset = datasets.ImageFolder(
        valdir, 
        transforms.Compose([
            transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
            transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.ToTensor(),
            normalize,
        ])
    )
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    valid_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        sampler=val_sampler,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        if to_output:
            print('Epoch %d start'%(epoch))
        lr_scheduler.step()
        train_sampler.set_epoch(epoch)
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, to_output=to_output)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, to_output=to_output)

        if dist.get_rank() == 0:
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False
                
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': config.MODEL.NAME,
                'state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir, filename='checkpoint.pth.tar')

    if to_output:
        final_model_state_file = os.path.join(final_output_dir,
                                              'final_state.pth.tar')
        logger.info('saving final model state to {}'.format(
            final_model_state_file))
        torch.save(model.module.state_dict(), final_model_state_file)
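
The setup in main() follows the usual DDP recipe: pin the process to its GPU, initialize the NCCL process group, wrap the model, and feed the loader with a DistributedSampler whose set_epoch is called every epoch, as done in the training loop above. A condensed sketch of that skeleton, assuming environment-variable rendezvous:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_ddp(local_rank, model, train_dataset, batch_size, num_workers):
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    model = DDP(model.cuda(), device_ids=[local_rank], output_device=local_rank)
    sampler = DistributedSampler(train_dataset)
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                        sampler=sampler, num_workers=num_workers,
                        pin_memory=True, drop_last=True)
    return model, loader, sampler  # remember to call sampler.set_epoch(epoch) each epoch
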
Example #23
0
def tasks_eval(model,
               dataset,
               cur_task_id,
               config,
               metadata,
               logbook,
               dataset_type="valid",
               dist_args=None):
    """log the accuracies of the new model on all observed tasks
    :param metadata:
    """
    assert dataset.complete_information_mode is True

    distributed = dist_args is not None
    if distributed:
        gpu = dist_args["gpu"]
        rank = dist_args["rank"]
    else:
        gpu = None
        rank = 0

    metrics_dict = {}
    for task_id in range(cur_task_id + 1):
        dataset.choose_task(task_id)
        dataloader = data.DataLoader(dataset,
                                     batch_size=config["batch_size"],
                                     shuffle=False,
                                     num_workers=config["num_workers"],
                                     pin_memory=True)
        _, metrics = evaluate(model,
                              dataloader,
                              config,
                              metadata,
                              test_mode=True,
                              gpu=gpu)
        for metric in metrics.keys():
            metrics_dict[f"task_{task_id}_{dataset_type}_{metric}"] = metrics[
                metric]
    dataset.load_tasks_up_to(cur_task_id)
    dataloader = data.DataLoader(dataset,
                                 batch_size=config["batch_size"],
                                 shuffle=False,
                                 num_workers=config["num_workers"],
                                 pin_memory=True)
    _, metrics = evaluate(model,
                          dataloader,
                          config,
                          metadata,
                          test_mode=True,
                          gpu=gpu)
    for metric in metrics.keys():
        metrics_dict[f"average_{dataset_type}_{metric}"] = metrics[metric]

    if rank == 0:
        utils.log_task(cur_task_id, metrics_dict, logbook)
        if distributed:
            dist.barrier()
            metrics_dict["rank"] = rank
            print_msg(metrics_dict)
    else:
        dist.barrier()
        metrics_dict["rank"] = rank
        print_msg(metrics_dict)
Example #24
0
def train(rank, experiment_name, world_size, continue_epoch, dist_url):
    print(f"Running rank {rank}/{world_size} dist url: {dist_url}.")
    # setup(rank, world_size)

    dist.init_process_group(backend="nccl", init_method=dist_url,
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    distributed = True
    model_str = experiment_name

    cfg = load_config_data(experiment_name)
    pprint.pprint(cfg)

    model_type = cfg["model_params"]["model_type"]
    train_params = DotDict(cfg["train_params"])

    checkpoints_dir = f"./checkpoints/{model_str}"
    tensorboard_dir = f"./tensorboard/{model_type}/{model_str}"
    oof_dir = f"./oof/{model_str}"
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print("\n", experiment_name, "\n")

    logger = SummaryWriter(log_dir=tensorboard_dir)

    scaler = torch.cuda.amp.GradScaler()

    with utils.timeit_context("load train"):
        dataset_train = dataset.LyftDatasetPrerendered(dset_name=dataset.LyftDataset.DSET_TRAIN_XXL, cfg_data=cfg)

    with utils.timeit_context("load validation"):
        dataset_valid = dataset.LyftDatasetPrerendered(dset_name=dataset.LyftDataset.DSET_VALIDATION, cfg_data=cfg)

    batch_size = dataset_train.dset_cfg["batch_size"]

    # train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, num_replicas=world_size)

    data_loaders = {
        "train": DataLoader(
            dataset_train,
            num_workers=16,
            shuffle=True,
            # sampler=train_sampler,
            batch_size=batch_size // world_size
        ),
        "val": DataLoader(
            dataset_valid,
            shuffle=False,
            num_workers=16,
            batch_size=dataset_valid.dset_cfg["batch_size"] // world_size,
        ),
    }
    model_info = DotDict(cfg["model_params"])
    model_orig = build_model(model_info, cfg).cuda(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model_orig)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)
    model.train()

    initial_lr = float(train_params.initial_lr)
    optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, nesterov=True)

    if continue_epoch > 0:
        # if rank == 0:
        fn = f"{checkpoints_dir}/{continue_epoch:03}.pt"
        print(f'loading {fn}...')
        # dist.barrier()

        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(fn, map_location=map_location)

        # if distributed:
        #     model.module.load_state_dict(checkpoint["model_state_dict"])

        model.module.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        dist.barrier()
        print(f'loaded {fn}')

    nb_epochs = train_params.nb_epochs
    scheduler = utils.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=train_params.scheduler_period,
        T_mult=train_params.get('scheduler_t_mult', 1),
        eta_min=initial_lr / 1000.0,
        last_epoch=-1
    )
    for i in range(continue_epoch + 1):
        scheduler.step()

    grad_clip_value = train_params.get("grad_clip", 2.0)
    print("grad clip:", grad_clip_value)

    print(f"Num training agents: {len(dataset_train)} validation agents: {len(dataset_valid)}")

    for epoch_num in range(continue_epoch + 1, nb_epochs + 1):
        for phase in ["train", "val"]:
            model.train(phase == "train")
            epoch_loss_regression = []
            data_loader = data_loaders[phase]

            optimizer.zero_grad()

            if phase == "train":
                nb_steps_per_epoch = train_params.epoch_size // batch_size
                data_iter = tqdm(
                    utils.LoopIterable(data_loader, max_iters=nb_steps_per_epoch),
                    total=nb_steps_per_epoch,
                    ncols=250,
                    # disable=rank > 0
                )
            else:
                if epoch_num % 2 == 1:  # skip every other validation epoch for speed
                    continue

                data_iter = tqdm(data_loader, ncols=250)

            for data in data_iter:
                with torch.set_grad_enabled(phase == "train"):
                    inputs = data["image"].float().cuda(rank, non_blocking=True)
                    target_availabilities = data["target_availabilities"].cuda(rank, non_blocking=True)
                    targets = data["target_positions"].cuda(rank, non_blocking=True)

                    optimizer.zero_grad()
                    loss_regression = 0
                    agent_state = None

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE:
                        with torch.cuda.amp.autocast():
                            pred, confidences = model(inputs)

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch_from_log_sm(
                                gt=targets.float(),
                                pred=pred.float(),
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_I4X:
                        with torch.cuda.amp.autocast():
                            pred, confidences = model(inputs, agent_state, data["image_4x"].float().cuda())

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                                gt=targets.float(),
                                pred=pred.float(),
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    loss = loss_regression

                    if phase == "train":
                        scaler.scale(loss).backward()

                        # Unscales the gradients of optimizer's assigned params in-place
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)
                        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
                        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
                        scaler.step(optimizer)
                        scaler.update()

                    if phase == "val":
                        # save predictions visualisation
                        pass

                    epoch_loss_regression.append(float(loss_regression))
                    loss_regression = None
                    del loss

                    data_iter.set_description(
                        f"{epoch_num} {phase[0]}"
                        f" Loss r {np.mean(epoch_loss_regression):1.4f} "
                    )

            if rank == 0:
                logger.add_scalar(f"loss_{phase}", np.mean(epoch_loss_regression), epoch_num)

                if phase == "train":
                    logger.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch_num)
                logger.flush()

            if phase == "train":
                scheduler.step()

                if rank == 0:
                    torch.save(
                        {
                            "epoch": epoch_num,
                            # "model_state_dict": model.state_dict(),
                            "model_state_dict": model.module.state_dict() if distributed else model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                        },
                        f"{checkpoints_dir}/{epoch_num:03}.pt",
                    )
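
The train phase above combines autocast, a GradScaler, and gradient clipping; the scaler must unscale the gradients before clip_grad_norm_ so the clip threshold applies to the true gradient magnitudes. A minimal per-step sketch (the batch keys and criterion are assumptions, not the multi-mode loss used above):

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, batch, criterion, optimizer, grad_clip=2.0):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(batch["inputs"]), batch["targets"])
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.unscale_(optimizer)         # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)             # skipped automatically on inf/NaN gradients
    scaler.update()
    return loss.detach()
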
Example #25
0
def PolarOffsetMain(args, cfg):
    if args.launcher is None:
        dist_train = False
    else:
        args.batch_size, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
            args.batch_size, args.tcp_port, args.local_rank, backend='nccl'
        )
        dist_train = True
    cfg['DIST_TRAIN'] = dist_train
    output_dir = os.path.join('./output', args.tag)
    ckpt_dir = os.path.join(output_dir, 'ckpt')
    tmp_dir = os.path.join(output_dir, 'tmp')
    summary_dir = os.path.join(output_dir, 'summary')
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(tmp_dir, exist_ok=True)
    os.makedirs(summary_dir, exist_ok=True)

    if args.onlyval and args.saveval:
        results_dir = os.path.join(output_dir, 'test', 'sequences')
        os.makedirs(results_dir, exist_ok=True)
        for i in range(8, 9):
            sub_dir = os.path.join(results_dir, str(i).zfill(2), 'predictions')
            os.makedirs(sub_dir, exist_ok=True)

    if args.onlytest:
        results_dir = os.path.join(output_dir, 'test', 'sequences')
        os.makedirs(results_dir, exist_ok=True)
        for i in range(11, 22):
            sub_dir = os.path.join(results_dir, str(i).zfill(2), 'predictions')
            os.makedirs(sub_dir, exist_ok=True)

    log_file = os.path.join(output_dir, ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S')))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    logger.info('**********************Start logging**********************')
    gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL'
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_train:
        total_gpus = dist.get_world_size()
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)
    if cfg.LOCAL_RANK == 0:
        os.system('cp %s %s' % (args.config, output_dir))

    ### create dataloader
    if (not args.onlytest) and (not args.onlyval):
        train_dataset_loader = build_dataloader(args, cfg, split='train', logger=logger)
        val_dataset_loader = build_dataloader(args, cfg, split='val', logger=logger, no_shuffle=True, no_aug=True)
    elif args.onlyval:
        val_dataset_loader = build_dataloader(args, cfg, split='val', logger=logger, no_shuffle=True, no_aug=True)
    else:
        test_dataset_loader = build_dataloader(args, cfg, split='test', logger=logger, no_shuffle=True, no_aug=True)

    ### create model
    model = build_network(cfg)
    model.cuda()

    ### create optimizer
    optimizer = train_utils.build_optimizer(model, cfg)

    ### load ckpt
    ckpt_fname = os.path.join(ckpt_dir, args.ckpt_name)
    epoch = -1

    other_state = {}
    if args.pretrained_ckpt is not None and os.path.exists(ckpt_fname):
        logger.info("Now in pretrain mode and loading ckpt: {}".format(ckpt_fname))
        if not args.nofix:
            if args.fix_semantic_instance:
                logger.info("Freezing backbone, semantic and instance part of the model.")
                model.fix_semantic_instance_parameters()
            else:
                logger.info("Freezing semantic and backbone part of the model.")
                model.fix_semantic_parameters()
        optimizer = train_utils.build_optimizer(model, cfg)
        epoch, other_state = train_utils.load_params_with_optimizer_otherstate(model, ckpt_fname, to_cpu=dist_train, optimizer=optimizer, logger=logger) # new feature
        logger.info("Loaded Epoch: {}".format(epoch))
    elif args.pretrained_ckpt is not None:
        train_utils.load_pretrained_model(model, args.pretrained_ckpt, to_cpu=dist_train, logger=logger)
        if not args.nofix:
            if args.fix_semantic_instance:
                logger.info("Freezing backbone, semantic and instance part of the model.")
                model.fix_semantic_instance_parameters()
            else:
                logger.info("Freezing semantic and backbone part of the model.")
                model.fix_semantic_parameters()
        else:
            logger.info("No Freeze.")
        optimizer = train_utils.build_optimizer(model, cfg)
    elif os.path.exists(ckpt_fname):
        epoch, other_state = train_utils.load_params_with_optimizer_otherstate(model, ckpt_fname, to_cpu=dist_train, optimizer=optimizer, logger=logger) # new feature
        logger.info("Loaded Epoch: {}".format(epoch))
    if other_state is None:
        other_state = {}

    ### create optimizer and scheduler
    lr_scheduler = None
    if lr_scheduler is None:
        logger.info('Not using lr scheduler')

    model.train()  # before wrap to DistributedDataParallel to support fixed some parameters
    if dist_train:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()], find_unused_parameters=True)
    logger.info(model)

    if cfg.LOCAL_RANK==0:
        writer = SummaryWriter(log_dir=summary_dir)

    logger.info('**********************Start Training**********************')
    rank = cfg.LOCAL_RANK
    best_before_iou = -1 if 'best_before_iou' not in other_state else other_state['best_before_iou']
    best_pq = -1 if 'best_pq' not in other_state else other_state['best_pq']
    best_after_iou = -1 if 'best_after_iou' not in other_state else other_state['best_after_iou']
    global_iter = 0 if 'global_iter' not in other_state else other_state['global_iter']
    val_global_iter = 0 if 'val_global_iter' not in other_state else other_state ['val_global_iter']

    ### test
    if args.onlytest:
        logger.info('----EPOCH {} Testing----'.format(epoch))
        model.eval()
        if rank == 0:
            vbar = tqdm(total=len(test_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(test_dataset_loader):
            with torch.no_grad():
                ret_dict = model(inputs, is_test=True, require_cluster=True, require_merge=True)
                common_utils.save_test_results(ret_dict, results_dir, inputs)
            if rank == 0:
                vbar.set_postfix({'fname':'/'.join(inputs['pcd_fname'][0].split('/')[-3:])})
                vbar.update(1)
        if rank == 0:
            vbar.close()
        logger.info("----Testing Finished----")
        return
    
    ### evaluate
    if args.onlyval:
        logger.info('----EPOCH {} Evaluating----'.format(epoch))
        model.eval()
        min_points = 50 # according to SemanticKITTI official rule
        before_merge_evaluator = init_eval(min_points=min_points)
        after_merge_evaluator = init_eval(min_points=min_points)

        if rank == 0:
            vbar = tqdm(total=len(val_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(val_dataset_loader):
            inputs['i_iter'] = i_iter
            torch.cuda.empty_cache()
            with torch.no_grad():
                ret_dict = model(inputs, is_test=True, before_merge_evaluator=before_merge_evaluator,
                                after_merge_evaluator=after_merge_evaluator, require_cluster=True)
                if args.saveval:
                    common_utils.save_test_results(ret_dict, results_dir, inputs)
            if rank == 0:
                vbar.set_postfix({'loss': ret_dict['loss'].item(),
                                  'fname':'/'.join(inputs['pcd_fname'][0].split('/')[-3:]),
                                  'ins_num': -1 if 'ins_num' not in ret_dict else ret_dict['ins_num']})
                vbar.update(1)
        if dist_train:
            before_merge_evaluator = common_utils.merge_evaluator(before_merge_evaluator, tmp_dir)
            dist.barrier()
            after_merge_evaluator = common_utils.merge_evaluator(after_merge_evaluator, tmp_dir)

        if rank == 0:
            vbar.close()
        if rank == 0:
            ## print results
            logger.info("Before Merge Semantic Scores")
            before_merge_results = printResults(before_merge_evaluator, logger=logger, sem_only=True)
            logger.info("After Merge Panoptic Scores")
            after_merge_results = printResults(after_merge_evaluator, logger=logger)

        logger.info("----Evaluating Finished----")
        return
    
    ### train
    while True:
        epoch += 1
        if 'MAX_EPOCH' in cfg.OPTIMIZE.keys():
            if epoch > cfg.OPTIMIZE.MAX_EPOCH:
                break

        ### train one epoch
        logger.info('----EPOCH {} Training----'.format(epoch))
        loss_acc = 0
        if rank == 0:
            pbar = tqdm(total=len(train_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(train_dataset_loader):
            torch.cuda.empty_cache()
            torch.autograd.set_detect_anomaly(True)
            model.train()
            optimizer.zero_grad()
            inputs['i_iter'] = i_iter
            inputs['rank'] = rank
            ret_dict = model(inputs)
            if args.pretrained_ckpt is not None and not args.fix_semantic_instance: # training offset
                if len(ret_dict['offset_loss_list']) > 0:
                    loss = sum(ret_dict['offset_loss_list'])
                else:
                    loss = torch.tensor(0.0, requires_grad=True)  # mock loss so the progress bar still updates
                    ret_dict['offset_loss_list'] = [loss]  # mock entry so the writer still logs
            elif args.pretrained_ckpt is not None and args.fix_semantic_instance: # training dynamic shifting
                loss = sum(ret_dict['meanshift_loss'])
            else:
                loss = ret_dict['loss']
            loss.backward()
            optimizer.step()

            torch.cuda.empty_cache()
            if rank == 0:
                try:
                    cur_lr = float(optimizer.lr)
                except AttributeError:
                    cur_lr = optimizer.param_groups[0]['lr']
                loss_acc += loss.item()
                pbar.set_postfix({'loss': loss.item(), 'lr': cur_lr, 'mean_loss': loss_acc / float(i_iter+1)})
                pbar.update(1)
                writer.add_scalar('Train/01_Loss', ret_dict['loss'].item(), global_iter)
                writer.add_scalar('Train/02_SemLoss', ret_dict['sem_loss'].item(), global_iter)
                if sum(ret_dict['offset_loss_list']).item() > 0:
                    writer.add_scalar('Train/03_InsLoss', sum(ret_dict['offset_loss_list']).item(), global_iter)
                writer.add_scalar('Train/04_LR', cur_lr, global_iter)
                writer_acc = 5
                if 'meanshift_loss' in ret_dict:
                    writer.add_scalar('Train/05_DSLoss', sum(ret_dict['meanshift_loss']).item(), global_iter)
                    writer_acc += 1
                more_keys = []
                for k, _ in ret_dict.items():
                    if k.find('summary') != -1:
                        more_keys.append(k)
                for ki, k in enumerate(more_keys):
                    if k == 'bandwidth_weight_summary':
                        continue
                    ki += writer_acc
                    writer.add_scalar('Train/{}_{}'.format(str(ki).zfill(2), k), ret_dict[k], global_iter)
                global_iter += 1
        if rank == 0:
            pbar.close()
        
        ### evaluate after each epoch
        logger.info('----EPOCH {} Evaluating----'.format(epoch))
        model.eval()
        min_points = 50
        before_merge_evaluator = init_eval(min_points=min_points)
        after_merge_evaluator = init_eval(min_points=min_points)
        if rank == 0:
            vbar = tqdm(total=len(val_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(val_dataset_loader):
            torch.cuda.empty_cache()
            inputs['i_iter'] = i_iter
            inputs['rank'] = rank
            with torch.no_grad():
                ret_dict = model(inputs, is_test=True, before_merge_evaluator=before_merge_evaluator,
                                after_merge_evaluator=after_merge_evaluator, require_cluster=True)
            if rank == 0:
                vbar.set_postfix({'loss': ret_dict['loss'].item()})
                vbar.update(1)
                writer.add_scalar('Val/01_Loss', ret_dict['loss'].item(), val_global_iter)
                writer.add_scalar('Val/02_SemLoss', ret_dict['sem_loss'].item(), val_global_iter)
                if sum(ret_dict['offset_loss_list']).item() > 0:
                    writer.add_scalar('Val/03_InsLoss', sum(ret_dict['offset_loss_list']).item(), val_global_iter)
                more_keys = []
                for k, _ in ret_dict.items():
                    if k.find('summary') != -1:
                        more_keys.append(k)
                for ki, k in enumerate(more_keys):
                    if k == 'bandwidth_weight_summary':
                        continue
                    ki += 4
                    writer.add_scalar('Val/{}_{}'.format(str(ki).zfill(2), k), ret_dict[k], val_global_iter)
                val_global_iter += 1
        if dist_train:
            try:
                before_merge_evaluator = common_utils.merge_evaluator(before_merge_evaluator, tmp_dir, prefix='before_')
                dist.barrier()
                after_merge_evaluator = common_utils.merge_evaluator(after_merge_evaluator, tmp_dir, prefix='after_')
            except Exception as e:
                print("Something went wrong when merging evaluator in rank {}: {}".format(rank, e))
        if rank == 0:
            vbar.close()
        if rank == 0:
            ## print results
            logger.info("Before Merge Semantic Scores")
            before_merge_results = printResults(before_merge_evaluator, logger=logger, sem_only=True)
            logger.info("After Merge Panoptic Scores")
            after_merge_results = printResults(after_merge_evaluator, logger=logger)
            ## save ckpt
            other_state = {
                'best_before_iou': best_before_iou,
                'best_pq': best_pq,
                'best_after_iou': best_after_iou,
                'global_iter': global_iter,
                'val_global_iter': val_global_iter
            }
            saved_flag = False
            if best_before_iou < before_merge_results['iou_mean']:
                best_before_iou = before_merge_results['iou_mean']
                if not saved_flag:
                    states = train_utils.checkpoint_state(model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(states, os.path.join(ckpt_dir,
                        'checkpoint_epoch_{}_{}_{}_{}.pth'.format(epoch, str(best_before_iou)[:5], str(best_pq)[:5], str(best_after_iou)[:5])))
                    saved_flag = True
            if best_pq < after_merge_results['pq_mean']:
                best_pq = after_merge_results['pq_mean']
                if not saved_flag:
                    states = train_utils.checkpoint_state(model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(states, os.path.join(ckpt_dir,
                        'checkpoint_epoch_{}_{}_{}_{}.pth'.format(epoch, str(best_before_iou)[:5], str(best_pq)[:5], str(best_after_iou)[:5])))
                    saved_flag = True
            if best_after_iou < after_merge_results['iou_mean']:
                best_after_iou = after_merge_results['iou_mean']
                if not saved_flag:
                    states = train_utils.checkpoint_state(model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(states, os.path.join(ckpt_dir,
                        'checkpoint_epoch_{}_{}_{}_{}.pth'.format(epoch, str(best_before_iou)[:5], str(best_pq)[:5], str(best_after_iou)[:5])))
                    saved_flag = True
            logger.info("Current best before IoU: {}".format(best_before_iou))
            logger.info("Current best after IoU: {}".format(best_after_iou))
            logger.info("Current best after PQ: {}".format(best_pq))
        if lr_scheduler is not None:
            lr_scheduler.step(epoch) # new feature
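
The checkpointing above bundles the model weights, optimizer state, epoch, and a small other_state dict; the earlier examples also strip the DDP wrapper before saving and pass map_location when resuming so tensors land on the local GPU. A hedged sketch combining those pieces (not the project's train_utils helpers):

import torch

def save_training_state(path, model, optimizer, epoch, other_state):
    # unwrap DDP so the checkpoint is loadable without a process group
    net = model.module if hasattr(model, "module") else model
    torch.save({"epoch": epoch,
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "other_state": other_state}, path)

def load_training_state(path, model, optimizer, local_rank):
    # map tensors saved from rank 0 onto this process's GPU
    ckpt = torch.load(path, map_location=f"cuda:{local_rank}")
    net = model.module if hasattr(model, "module") else model
    net.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"], ckpt.get("other_state", {})
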
Example #26
0
    def synthesize(self,
                   num,
                   z=None,
                   html_name=None,
                   save_raw_synthesis=False):
        """Synthesizes images.

        Args:
            num: Number of images to synthesize.
            z: Latent codes used for generation. If not specified, this function
                will sample latent codes randomly. (default: None)
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved. (default: None)
            save_raw_synthesis: Whether to save raw synthesis on the disk.
                (default: False)
        """
        if not html_name and not save_raw_synthesis:
            return

        self.set_mode('val')

        temp_dir = os.path.join(self.work_dir, 'synthesize_results')
        os.makedirs(temp_dir, exist_ok=True)

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            num = min(num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not num:
            return
        # TODO: Use same z during the entire training process.

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('Synthesize', total=num)

        indices = list(range(self.rank, num, self.world_size))
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            if z is None:
                code = torch.randn(batch_size, self.z_space_dim).cuda()
            else:
                code = z[sub_indices].cuda()
            with torch.no_grad():
                if 'generator_smooth' in self.models:
                    G = self.models['generator_smooth']
                else:
                    G = self.models['generator']
                images = G(code, **self.G_kwargs_val)['image']
                images = self.postprocess(images)
            for sub_idx, image in zip(sub_indices, images):
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}.jpg'), image)
            self.logger.update_pbar(task1, batch_size * self.world_size)

        dist.barrier()
        if self.rank != 0:
            return

        if html_name:
            task2 = self.logger.add_pbar_task('Visualize', total=num)
            html = HtmlPageVisualizer(grid_size=num)
            for image_idx in range(num):
                image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}.jpg'))
                row_idx, col_idx = divmod(image_idx, html.num_cols)
                html.set_cell(row_idx,
                              col_idx,
                              image=image,
                              text=f'Sample {image_idx:06d}')
                self.logger.update_pbar(task2, 1)
            html.save(os.path.join(self.work_dir, html_name))
        if not save_raw_synthesis:
            shutil.rmtree(temp_dir)

        self.logger.close_pbar()
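
The synthesize loop shards the work with range(rank, num, world_size) and then meets at a barrier so rank 0 only aggregates once every shard is on disk. A minimal sketch of that pattern (render_one and aggregate are hypothetical callables):

import torch.distributed as dist

def render_sharded(num_items, rank, world_size, render_one, aggregate):
    # round-robin shard: rank r handles items r, r + world_size, r + 2 * world_size, ...
    for idx in range(rank, num_items, world_size):
        render_one(idx)
    dist.barrier()            # wait until every rank has written its shard
    if rank == 0:
        aggregate(num_items)  # only the main process builds the final output
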
Example #27
0
    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()
def single_process_function(rank, 
                            world_size, 
                            lr, 
                            model, 
                            inputs, 
                            labels, 
                            loss_fn, 
                            miner_fn, 
                            original_model, 
                            original_loss_fn, 
                            original_miner_fn, 
                            correct_loss, 
                            correct_indices_tuple,
                            is_tuple_loss,
                            ref_outputs,
                            ref_labels):
    setup(rank, world_size)
    device = torch.device("cuda:{}".format(rank))

    ddp_mp_model = DDP(model.to(device), device_ids=[rank], output_device=rank)

    if is_tuple_loss:
        loss_fn = distributed.DistributedLossWrapper(loss=loss_fn)
    else:
        loss_fn = distributed.DistributedLossWrapper(loss=loss_fn.to(device), device_ids=[rank], output_device=rank)
        loss_optimizer = optim.SGD(loss_fn.parameters(), lr=lr)
        loss_optimizer.zero_grad()

    miner_fn = distributed.DistributedMinerWrapper(miner=miner_fn)

    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr)
    optimizer.zero_grad()
    outputs = ddp_mp_model(inputs[rank].to(device))

    if ref_outputs is not None:
        ref_outputs[rank] = ref_outputs[rank].to(device)
        indices_tuple = miner_fn(outputs, labels[rank], ref_outputs[rank], ref_labels[rank])
        indices_tuple = c_f.shift_indices_tuple(indices_tuple, len(outputs)*world_size)
        loss = loss_fn([outputs, ref_outputs[rank]], [labels[rank], ref_labels[rank]], indices_tuple)
    else:
        indices_tuple = miner_fn(outputs, labels[rank])
        loss = loss_fn(outputs, labels[rank], indices_tuple)

    if is_tuple_loss:
        pos_loss_size = loss_fn.loss.reducer.reducers["pos_loss"].losses_size
        neg_loss_size = loss_fn.loss.reducer.reducers["neg_loss"].losses_size
        correct_pos_loss_size = original_loss_fn.reducer.reducers["pos_loss"].losses_size
        correct_neg_loss_size = original_loss_fn.reducer.reducers["neg_loss"].losses_size
        assert pos_loss_size == correct_pos_loss_size
        assert neg_loss_size == correct_neg_loss_size
    else:
        loss_size = loss_fn.loss.module.reducer.losses_size
        correct_loss_size = original_loss_fn.reducer.losses_size
        assert loss_size == correct_loss_size

    assert torch.isclose(loss, torch.from_numpy(correct_loss).to(device))
    assert miner_fn.miner.num_pos_pairs == original_miner_fn.num_pos_pairs
    assert miner_fn.miner.num_neg_pairs == original_miner_fn.num_neg_pairs
    for i in range(len(correct_indices_tuple)):
        assert torch.all(indices_tuple[i] == (torch.from_numpy(correct_indices_tuple[i]).to(device)))

    dist.barrier()
    loss.backward()

    original_model = original_model.to(device)
    assert not parameters_are_equal(original_model, ddp_mp_model.module)
    dist.barrier()
    optimizer.step()
    dist.barrier()
    assert parameters_are_equal(original_model, ddp_mp_model.module)

    if not is_tuple_loss:
        original_loss_fn = original_loss_fn.to(device)
        assert not parameters_are_equal(original_loss_fn, loss_fn.loss.module)
        dist.barrier()
        loss_optimizer.step()
        dist.barrier()
        assert parameters_are_equal(original_loss_fn, loss_fn.loss.module)

    dist.barrier()
    cleanup()
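
The assertions above rely on a parameters_are_equal helper that is not shown in this snippet; a minimal version consistent with how it is used might look like the following (an assumption, not the repository's actual implementation):

import torch

def parameters_are_equal(model_a, model_b) -> bool:
    # compare every corresponding parameter tensor on the CPU
    params_a = list(model_a.parameters())
    params_b = list(model_b.parameters())
    if len(params_a) != len(params_b):
        return False
    return all(torch.allclose(p_a.detach().cpu(), p_b.detach().cpu())
               for p_a, p_b in zip(params_a, params_b))
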
Example #29
0
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        with_metrics = self.opt['val']['metrics'] is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics and not hasattr(self, 'metric_results'):
            self.metric_results = {}
            num_frame_each_folder = Counter(dataset.data_info['folder'])
            for folder, num_frame in num_frame_each_folder.items():
                self.metric_results[folder] = torch.zeros(
                    num_frame,
                    len(self.opt['val']['metrics']),
                    dtype=torch.float32,
                    device='cuda')

        rank, world_size = get_dist_info()
        for _, tensor in self.metric_results.items():
            tensor.zero_()
        # record all frames (border and center frames)
        if rank == 0:
            pbar = ProgressBar(len(dataset))
        for idx in range(rank, len(dataset), world_size):
            val_data = dataset[idx]
            val_data['lq'].unsqueeze_(0)
            val_data['gt'].unsqueeze_(0)
            folder = val_data['folder']
            frame_idx, max_idx = val_data['idx'].split('/')
            lq_path = val_data['lq_path']

            self.feed_data(val_data)
            self.test()
            visuals = self.get_current_visuals()

            result_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # free tensors to avoid running out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    raise NotImplementedError(
                        'saving image is not supported during training.')
                else:
                    if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
                        split_result = lq_path.split('/')
                        img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                    f'{split_result[-1].split(".")[0]}')
                    else:  # other datasets, e.g., REDS, Vid4
                        img_name = osp.splitext(osp.basename(lq_path))[0]

                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            folder,
                            f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        split_result = lq_path.split('/')
                        img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                    f'{split_result[-1].split(".")[0]}')
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], folder,
                            f'{img_name}.png')

                np_save_img_path = osp.splitext(save_img_path)[0] + '.npy'
                os.makedirs(
                    osp.join(self.opt['path']['visualization'], folder),
                    exist_ok=True)
                np.save(
                    np_save_img_path,
                    np.array([
                        visuals['embedding_gt'], visuals['embedding_out'],
                        visuals['embedding_center']
                    ]))
                mmcv.imwrite(result_img, save_img_path)

            if with_metrics:
                # compute cosine similarity between normalized output and GT embeddings
                opt_metric = deepcopy(self.opt['val']['metrics'])
                for metric_idx, opt_ in enumerate(opt_metric.values()):
                    out_emb = visuals['embedding_out']
                    gt_emb = visuals['embedding_gt']

                    gt = gt_emb / np.sqrt(np.sum(gt_emb**2, -1, keepdims=True))
                    out = out_emb / np.sqrt(
                        np.sum(out_emb**2, -1, keepdims=True))
                    cos_similarity = np.mean(np.sum(gt * out, -1))
                    self.metric_results[folder][int(frame_idx),
                                                metric_idx] += cos_similarity

            # progress bar (rank 0 advances it on behalf of all ranks)
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(f'Test {folder} - '
                                f'{int(frame_idx) + world_size}/{max_idx}')

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for _, tensor in self.metric_results.items():
                    dist.reduce(tensor, 0)
                dist.barrier()
            else:
                pass  # assume use one gpu in non-dist testing

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name,
                                                   tb_logger)
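
The aggregation pattern above generalizes: each rank writes only the rows it processed into a zero-initialized tensor, so a sum-reduce onto rank 0 reconstructs the full metric table, and the barrier keeps the ranks in step before rank 0 logs. A minimal sketch of just that step, assuming the process group is already initialized:

    import torch.distributed as dist

    def collect_metric_results(metric_results: dict, dst: int = 0) -> None:
        # metric_results: {folder: tensor of shape (num_frames, num_metrics)},
        # with rows owned by other ranks still zero on this rank
        for tensor in metric_results.values():
            dist.reduce(tensor, dst=dst, op=dist.ReduceOp.SUM)
        dist.barrier()  # every rank has contributed before dst logs the results
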
Example #30
0
    def validate(self,
                 do_mirroring: bool = True,
                 use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None,
                 run_postprocessing_on_folds: bool = True):
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = False

        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params'][
                    'force_separate_z']
                interpolation_order = self.plans['segmentation_export_params'][
                    'interpolation_order']
                interpolation_order_z = self.plans[
                    'segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs[
                'interpolation_order']
            interpolation_order_z = segmentation_export_kwargs[
                'interpolation_order_z']

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        os.makedirs(output_folder, exist_ok=True)
        # this is for debug purposes
        my_input_args = {
            'do_mirroring': do_mirroring,
            'use_sliding_window': use_sliding_window,
            'step_size': step_size,
            'save_softmax': save_softmax,
            'use_gaussian': use_gaussian,
            'overwrite': overwrite,
            'validation_folder_name': validation_folder_name,
            'debug': debug,
            'all_in_gpu': all_in_gpu,
            'segmentation_export_kwargs': segmentation_export_kwargs,
        }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError(
                    "We did not train with mirroring so you cannot do inference with mirroring enabled"
                )
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        export_pool = Pool(default_num_threads)
        results = []

        all_keys = list(self.dataset_val.keys())
        my_keys = all_keys[self.local_rank::dist.get_world_size()]
        # we cannot simply iterate over all_keys because we need to know pred_gt_tuples and valid_labels of all cases
        # for evaluation (which is done by local rank 0)
        for k in all_keys:
            properties = load_pickle(self.dataset[k]['properties_file'])
            fname = os.path.basename(properties['list_of_data_files'][0])[:-12]
            pred_gt_tuples.append([
                join(output_folder, fname + ".nii.gz"),
                join(self.gt_niftis_folder, fname + ".nii.gz")
            ])
            if k in my_keys:
                if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
                        (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                    data = np.load(self.dataset[k]['data_file'])['data']

                    print(k, data.shape)
                    data[-1][data[-1] == -1] = 0

                    softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                        data[:-1],
                        do_mirroring=do_mirroring,
                        mirror_axes=mirror_axes,
                        use_sliding_window=use_sliding_window,
                        step_size=step_size,
                        use_gaussian=use_gaussian,
                        all_in_gpu=all_in_gpu,
                        mixed_precision=self.fp16)[1]

                    softmax_pred = softmax_pred.transpose(
                        [0] + [i + 1 for i in self.transpose_backward])

                    if save_softmax:
                        softmax_fname = join(output_folder, fname + ".npz")
                    else:
                        softmax_fname = None
                    """There is a problem with python process communication that prevents us from communicating obejcts
                    larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
                    communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
                    enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
                    patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
                    then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
                    filename or np.ndarray and will handle this automatically"""
                    if np.prod(softmax_pred.shape) > (
                            2e9 / 4 * 0.85):  # *0.85 just to be save
                        np.save(join(output_folder, fname + ".npy"),
                                softmax_pred)
                        softmax_pred = join(output_folder, fname + ".npy")

                    results.append(
                        export_pool.starmap_async(
                            save_segmentation_nifti_from_softmax,
                            ((softmax_pred,
                              join(output_folder,
                                   fname + ".nii.gz"), properties,
                              interpolation_order, self.regions_class_order,
                              None, None, softmax_fname, None,
                              force_separate_z, interpolation_order_z), )))

        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        dist.barrier()

        if self.local_rank == 0:
            # evaluate raw predictions
            self.print_to_log_file("evaluation of raw predictions")
            task = os.path.basename(self.dataset_directory)
            job_name = self.experiment_name
            _ = aggregate_scores(pred_gt_tuples,
                                 labels=list(range(self.num_classes)),
                                 json_output_file=join(output_folder,
                                                       "summary.json"),
                                 json_name=job_name + " val tiled %s" %
                                 (str(use_sliding_window)),
                                 json_author="Fabian",
                                 json_task=task,
                                 num_threads=default_num_threads)

            if run_postprocessing_on_folds:
                # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
                # except the largest connected component for each class. To see if this improves results, we do this for all
                # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
                # have this applied during inference as well
                self.print_to_log_file("determining postprocessing")
                determine_postprocessing(
                    self.output_folder,
                    self.gt_niftis_folder,
                    validation_folder_name,
                    final_subf_name=validation_folder_name + "_postprocessed",
                    debug=debug)
                # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
                # They are always in that folder, even if no postprocessing was applied!

            # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
            # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
            # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
            # be used later
            gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
            os.makedirs(gt_nifti_folder, exist_ok=True)
            for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
                success = False
                attempts = 0
                e = None
                while not success and attempts < 10:
                    try:
                        shutil.copy(f, gt_nifti_folder)
                        success = True
                    except OSError as exc:
                        # bind to a separate name: `exc` is unbound after the except
                        # block, so `e` is what the failure path below re-raises
                        e = exc
                        attempts += 1
                        sleep(1)
                if not success:
                    print("Could not copy gt nifti file %s into folder %s" %
                          (f, gt_nifti_folder))
                    if e is not None:
                        raise e

        self.network.train(current_mode)
        net.do_ds = ds
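
Only the control flow matters for the distributed part of the validation above: cases are sharded round-robin across ranks, a barrier waits until every rank has written its predictions, and rank 0 alone aggregates. A stripped-down sketch with placeholder predict/evaluate callables standing in for the nnU-Net internals:

    import torch.distributed as dist

    def sharded_validation(all_keys, predict_case, evaluate_all):
        rank = dist.get_rank()
        # round-robin split: rank r handles keys r, r + world_size, r + 2*world_size, ...
        my_keys = set(all_keys[rank::dist.get_world_size()])

        for k in all_keys:          # every rank walks the full list for bookkeeping
            if k in my_keys:        # but only predicts its own shard
                predict_case(k)

        dist.barrier()              # all predictions are on disk before aggregation
        if rank == 0:
            evaluate_all(all_keys)  # evaluation/postprocessing happens once, on rank 0
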
Example #31
0
    def optimizer_step(self, loss, model, optimizer) -> None:
        """Abstraction over the ``optimizer.step()`` call."""
        optimizer.step()
        dist.barrier()
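
A hypothetical driver showing where a hook like the one above would sit in a training step (the runner object and batch layout are assumptions for illustration). The barrier is not what synchronizes gradients; DDP already all-reduces them during backward, the barrier merely keeps ranks in lock-step:

    def train_step(runner, batch, model, criterion, optimizer):
        optimizer.zero_grad()
        loss = criterion(model(batch["features"]), batch["targets"])
        loss.backward()                                # DDP all-reduces gradients here
        runner.optimizer_step(loss, model, optimizer)  # step + barrier from the example
        return loss.detach()
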
Example #32
0
    def deinit_components(self):
        """Deinit the run's components."""
        dist.barrier()
        self.cleanup_process()
Example #33
0
    def run_pretrain_routine(self, model: LightningModule):
        """Sanity check a few things before starting actual training.

        Args:
            model: The model to run sanity test on.
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.copy_trainer_model_properties(ref_model)

        # log hyper-parameters
        if self.logger is not None:
            # save exp to get started
            if hasattr(ref_model, "hparams"):
                self.logger.log_hyperparams(ref_model.hparams)

            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            dist.barrier()

        # wait for all models to restore weights
        if self.on_tpu and XLA_AVAILABLE:
            # wait for all processes to catch up
            torch_xla.core.xla_model.rendezvous(
                "pl.Trainer.run_pretrain_routine")

        # set up checkpoint callback
        self.configure_checkpoint_callback()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # print model summary
        if self.proc_rank == 0 and self.weights_summary is not None:
            if self.weights_summary in ['full', 'top']:
                ref_model.summarize(mode=self.weights_summary)
            else:
                m = "weights_summary can be None, 'full' or 'top'"
                raise MisconfigurationException(m)

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_weights(model)

        # download the data and do whatever transforms we need
        self.call_prepare_data(ref_model)

        # when testing requested only run test and return
        if self.testing:
            # only load test dataloader for testing
            self.reset_test_dataloader(ref_model)
            self.run_evaluation(test_mode=True)
            return

        # check if we should run validation during training
        self.disable_validation = not self.is_overriden(
            'validation_step') and not self.fast_dev_run

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        if not self.disable_validation and self.num_sanity_val_steps > 0:
            self.reset_val_dataloader(ref_model)
            # init progress bars for validation sanity check
            pbar = tqdm(desc='Validation sanity check',
                        total=self.num_sanity_val_steps *
                        len(self.val_dataloaders),
                        leave=False,
                        position=2 * self.process_position,
                        disable=not self.show_progress_bar,
                        dynamic_ncols=True)
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm(disable=True)

            eval_results = self.evaluate(model, self.val_dataloaders,
                                         self.num_sanity_val_steps, False)
            _, _, _, callback_metrics, _ = self.process_output(eval_results)

            # close progress bars
            self.main_progress_bar.close()
            self.val_progress_bar.close()

            if self.enable_early_stop:
                self.early_stop_callback.check_metrics(callback_metrics)

        # init progress bar
        pbar = tqdm(leave=True,
                    position=2 * self.process_position,
                    disable=not self.show_progress_bar,
                    dynamic_ncols=True,
                    file=sys.stdout)
        self.main_progress_bar = pbar

        # clear cache before training
        if self.on_gpu:
            torch.cuda.empty_cache()

        # CORE TRAINING LOOP
        self.train()
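
The guarded synchronization used above (only calling dist.barrier() when a distributed backend is actually in use) is a common defensive pattern; a small sketch, not Lightning's API:

    import torch.distributed as dist

    def barrier_if_initialized():
        # no-op in single-process runs, rendezvous point in DDP/DDP2 runs
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
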