Example #1
import os
import random
import time

import torch
import torch.backends.cudnn as cudnn
import yaml

# Project-level imports assumed from the repository:
# add_learner_params, Logger, models, utils (neptune is imported lazily below).

def main():
    args = add_learner_params()
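    # Make runs reproducible when a fixed seed is requested (seed == -1 disables this).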
    if args.seed != -1:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
    args.root = 'logs/'+args.name+'/'

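    # Optional Neptune experiment tracking; requires NEPTUNE_API_TOKEN in the environment.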
    if args.neptune:
        import neptune
        project = "arighosh/pretrain_noisy_label"
        neptune.init(project_qualified_name=project,
                     api_token=os.environ["NEPTUNE_API_TOKEN"])
        neptune.create_experiment(
            name=args.name, send_hardware_metrics=False, params=vars(args))
    fmt = {
        'train_time': '.3f',
        'val_time': '.3f',
        'train_epoch': '.1f',
        'lr': '.1e',
    }
    logger = Logger('logs', base=args.root, fmt=fmt)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.cuda:
        assert device.type == 'cuda', 'no gpu found!'

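    # Persist the full run configuration next to the logs.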
    with open(args.root+'config.yml', 'w') as outfile:
        yaml.dump(vars(args), outfile, default_flow_style=False)

    # create model
    model = models.REGISTERED_MODELS[args.problem](args, device=device)
    cur_iter = 0
    # Data loading code
    model.prepare_data()

    continue_training = cur_iter < args.iters
    data_time, it_time = 0, 0
    best_acc = 0.
    best_valid_acc, best_acc_with_valid = 0, 0

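    # Main training loop: run until args.iters steps, evaluating on the test and
    # validation loaders every args.eval_freq iterations.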
    while continue_training:
        train_loader, test_loader, valid_loader, meta_loader = model.dataloaders(
            iters=args.iters)
        train_logs = []
        model.train()
        start_time = time.time()
        for _, batch in enumerate(train_loader):
            cur_iter += 1
            batch = [x.to(device) for x in batch]
            data_time += time.time() - start_time
            logs = {}
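            # Every problem except plain fine-tuning also consumes a meta batch per step.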
            if args.problem not in {'finetune'}:
                meta_batch = next(iter(meta_loader))
                meta_batch = [x.to(device) for x in meta_batch]
                logs = model.train_step(batch, meta_batch, cur_iter)
            else:
                logs = model.train_step(batch, cur_iter)

            # save logs for the batch
            train_logs.append({k: utils.tonp(v) for k, v in logs.items()})
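            # Periodic evaluation on the test and validation sets.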
            if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters:
                test_start_time = time.time()
                test_logs, valid_logs = [], []
                model.eval()
                with torch.no_grad():
                    for batch in test_loader:
                        batch = [x.to(device) for x in batch]
                        logs = model.test_step(batch)
                        test_logs.append(logs)
                    for batch in valid_loader:
                        batch = [x.to(device) for x in batch]
                        logs = model.test_step(batch)
                        valid_logs.append(logs)
                model.train()
                test_logs = utils.agg_all_metrics(test_logs)
                valid_logs = utils.agg_all_metrics(valid_logs)
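                # Track the best test accuracy overall and the test accuracy at the
                # iteration with the best validation accuracy (model selection).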
                best_acc = max(best_acc, float(test_logs['acc']))
                test_logs['best_acc'] = best_acc
                if float(valid_logs['acc']) > best_valid_acc:
                    best_valid_acc = float(valid_logs['acc'])
                    best_acc_with_valid = float(test_logs['acc'])
                test_logs['best_acc_with_valid'] = best_acc_with_valid

                if args.neptune:
                    for k, v in test_logs.items():
                        neptune.log_metric('test_'+k, float(v))
                    for k, v in valid_logs.items():
                        neptune.log_metric('valid_'+k, float(v))
                    test_it_time = time.time()-test_start_time
                    neptune.log_metric('test_it_time', test_it_time)
                    neptune.log_metric('test_cur_iter', cur_iter)
                logger.add_logs(cur_iter, test_logs, pref='test_')
            it_time += time.time() - start_time
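            # Aggregate and flush training metrics every args.log_freq iterations.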
            if (cur_iter % args.log_freq == 0 or cur_iter >= args.iters):
                train_logs = utils.agg_all_metrics(train_logs)
                if args.neptune:
                    for k, v in train_logs.items():
                        neptune.log_metric('train_'+k, float(v))
                    neptune.log_metric('train_it_time', it_time)
                    neptune.log_metric('train_data_time', data_time)
                    neptune.log_metric(
                        'train_lr', model.optimizer.param_groups[0]['lr'])
                    neptune.log_metric('train_cur_iter', cur_iter)
                logger.add_logs(cur_iter, train_logs, pref='train_')
                logger.add_scalar(
                    cur_iter, 'lr', model.optimizer.param_groups[0]['lr'])
                logger.add_scalar(cur_iter, 'data_time', data_time)
                logger.add_scalar(cur_iter, 'it_time', it_time)
                logger.iter_info()
                logger.save()
                data_time, it_time = 0, 0
                train_logs = []
            if cur_iter >= args.iters:
                continue_training = False
                break
            start_time = time.time()
Example #2
import socket
import time

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

# Project-level imports assumed from the repository:
# Logger, models (including models.ssl), utils, save_checkpoint.

def main_worker(gpu, ngpus, args):
    fmt = {
        'train_time': '.3f',
        'val_time': '.3f',
        'lr': '.1e',
    }
    logger = Logger('logs', base=args.root, fmt=fmt)

    args.gpu = gpu
    torch.cuda.set_device(gpu)
    args.rank = args.node_rank * ngpus + gpu

    device = torch.device('cuda:%d' % args.gpu)

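    # DistributedDataParallel setup: one process per GPU; the global batch size is
    # split evenly across all processes.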
    if args.dist == 'ddp':
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://%s' % args.dist_address,
            world_size=args.world_size,
            rank=args.rank,
        )

        n_gpus_total = dist.get_world_size()
        assert args.batch_size % n_gpus_total == 0
        args.batch_size //= n_gpus_total
        if args.rank == 0:
            print(
                f'===> {n_gpus_total} GPUs total; batch_size={args.batch_size} per GPU'
            )

        print(
            f'===> Proc {dist.get_rank()}/{dist.get_world_size()}@{socket.gethostname()}',
            flush=True)

    # create model
    model = models.REGISTERED_MODELS[args.problem](args, device=device)

    if args.ckpt != '':
        ckpt = torch.load(args.ckpt, map_location=device)
        model.load_state_dict(ckpt['state_dict'])

    # Data loading code
    model.prepare_data()
    train_loader, val_loader = model.dataloaders(iters=args.iters)

    # define optimizer
    cur_iter = 0
    optimizer, scheduler = models.ssl.configure_optimizers(
        args, model, cur_iter - 1)

    # optionally resume from a checkpoint
    if args.ckpt and not args.eval_only:
        optimizer.load_state_dict(ckpt['opt_state_dict'])

    cudnn.benchmark = True

    continue_training = args.iters != 0
    data_time, it_time = 0, 0

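    # Training loop: run until args.iters steps, checkpointing and evaluating periodically.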
    while continue_training:
        train_logs = []
        model.train()

        start_time = time.time()
        for _, batch in enumerate(train_loader):
            cur_iter += 1

            batch = [x.to(device) for x in batch]
            data_time += time.time() - start_time

            logs = {}
            if not args.eval_only:
                # forward pass and compute loss
                logs = model.train_step(batch, cur_iter)
                loss = logs['loss']

                # gradient step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # save logs for the batch
            train_logs.append({k: utils.tonp(v) for k, v in logs.items()})

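            # Checkpoint periodically, from the rank-0 process only.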
            if cur_iter % args.save_freq == 0 and args.rank == 0:
                save_checkpoint(args.root, model, optimizer, cur_iter)

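            # Periodic evaluation on the validation loader.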
            if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters:
                # TODO: aggregate metrics over all processes
                test_logs = []
                model.eval()
                with torch.no_grad():
                    for batch in val_loader:
                        batch = [x.to(device) for x in batch]
                        # forward pass
                        logs = model.test_step(batch)
                        # save logs for the batch
                        test_logs.append(logs)
                model.train()

                test_logs = utils.agg_all_metrics(test_logs)
                logger.add_logs(cur_iter, test_logs, pref='test_')

            it_time += time.time() - start_time

            if (cur_iter % args.log_freq == 0
                    or cur_iter >= args.iters) and args.rank == 0:
                save_checkpoint(args.root, model, optimizer)
                train_logs = utils.agg_all_metrics(train_logs)

                logger.add_logs(cur_iter, train_logs, pref='train_')
                logger.add_scalar(cur_iter, 'lr',
                                  optimizer.param_groups[0]['lr'])
                logger.add_scalar(cur_iter, 'data_time', data_time)
                logger.add_scalar(cur_iter, 'it_time', it_time)
                logger.iter_info()
                logger.save()

                data_time, it_time = 0, 0
                train_logs = []

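            # Step the learning-rate scheduler once per iteration.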
            if scheduler is not None:
                scheduler.step()

            if cur_iter >= args.iters:
                continue_training = False
                break

            start_time = time.time()

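    # Save a final checkpoint and tear down the process group.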
    save_checkpoint(args.root, model, optimizer)

    if args.dist == 'ddp':
        dist.destroy_process_group()
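
A minimal launcher sketch for Example #2, stated as an assumption rather than code from the repository: main_worker(gpu, ngpus, args) has the per-process signature expected by torch.multiprocessing.spawn, so a typical entry point would spawn one worker per visible GPU. The parse_args helper and the args.nodes field are hypothetical placeholders.

import torch
import torch.multiprocessing as mp


def main():
    # Hypothetical parser providing the fields used by main_worker
    # (problem, iters, dist, dist_address, node_rank, nodes, ...).
    args = parse_args()
    ngpus = torch.cuda.device_count()
    # Total number of processes across all nodes (assumption: ngpus GPUs per node).
    args.world_size = ngpus * args.nodes
    # Launch one process per GPU; each worker receives its local GPU index
    # as the first argument, matching main_worker(gpu, ngpus, args).
    mp.spawn(main_worker, args=(ngpus, args), nprocs=ngpus)


if __name__ == '__main__':
    main()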