Example #1
                                 smooth_penalty=smooth_penalty, chunk_length=chunk_length)
    multi_shoot.prepare_intermediate(input_tensor)

    # create optimizer
    optimizer = AdaBelief(filter(lambda p: p.requires_grad, multi_shoot.parameters()), lr=lr, eps=1e-16, rectify=False,
                          betas=(0.5, 0.9))
    #optimizer = Adam(filter(lambda p: p.requires_grad, multi_shoot.parameters()), lr=lr, eps=1e-16,
    #                      betas=(0.5, 0.9))

    best_loss = np.inf
    for _epoch in range(N_epoch):
        # adjust learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] *= gamma

        optimizer.zero_grad()

        # forward pass and loss
        prediction_chunks, data_chunks = multi_shoot.fit_and_grad(input_tensor, time_points)
        loss = multi_shoot.get_loss(prediction_chunks, data_chunks)

        loss.backward(retain_graph=False)
        optimizer.step()

        print('Epoch {} Alpha {}, Beta {}, Gamma {}, Delta {}'.format(
            _epoch, dcmfunc.alpha, dcmfunc.beta, dcmfunc.gamma, dcmfunc.delta))

        if loss.item() < best_loss:
            # concatenate by time, and plot
            prediction2, data2 = [], []
            for prediction, data in zip(prediction_chunks, data_chunks):
                if data.shape[0] > 1:
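
The fragment above is truncated mid-branch. As a hedged sketch of the usual pattern (not the original code), chunks longer than one step are stitched back into a single trajectory for plotting whenever a new best loss is found; stitch_chunks is a hypothetical helper name:

import torch

def stitch_chunks(prediction_chunks, data_chunks):
    # Concatenate multiple-shooting chunks along the time axis, skipping
    # degenerate length-1 chunks, so prediction and data can be plotted
    # as single continuous series.
    preds, datas = [], []
    for prediction, data in zip(prediction_chunks, data_chunks):
        if data.shape[0] > 1:
            preds.append(prediction)
            datas.append(data)
    prediction2 = torch.cat(preds, dim=0).detach().cpu().numpy()
    data2 = torch.cat(datas, dim=0).detach().cpu().numpy()
    return prediction2, data2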
Example #2
def memo_valor(env_fn,
               model=MEMO,
               memo_kwargs=dict(),
               annealing_kwargs=dict(),
               seed=0,
               episodes_per_expert=40,
               epochs=50,
               # warmup=10,
               train_iters=5,
               step_size=5,
               memo_lr=1e-3,
               train_batch_size=50,
               eval_batch_size=200,
               max_ep_len=1000,
               logger_kwargs=dict(),
               config_name='standard',
               save_freq=10,
               # replay_buffers=[],
               memories=[]):
    # W&B Logging
    wandb.login()

    composite_name = 'E ' + str(epochs) + ' B ' + str(train_batch_size) + ' ENC ' + \
                     str(memo_kwargs['encoder_hidden']) + ' DEC ' + str(memo_kwargs['decoder_hidden'])

    wandb.init(project="MEMO", group='Epochs: ' + str(epochs), name=composite_name, config=locals())

    assert memories != [], "No examples found! Replay/memory buffers must be set to proceed."

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Model: create the MEMO model and monitor it
    con_dim = len(memories)
    memo = model(obs_dim=obs_dim[0], out_dim=act_dim[0], **memo_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver([memo])

    # Sync params across processes
    sync_params(memo)
    N_expert = episodes_per_expert * max_ep_len
    print("N Expert: ", N_expert)

    # Buffer
    # local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    local_iter_per_epoch = int(train_iters / num_procs())

    # Count variables
    var_counts = tuple(count_vars(module) for module in [memo])
    logger.log('\nNumber of parameters: \t d: %d\n' % var_counts)

    # Optimizers
    # memo_optimizer = AdaBelief(memo.parameters(), lr=memo_lr, eps=1e-20, rectify=True)
    memo_optimizer = AdaBelief(memo.parameters(), lr=memo_lr, eps=1e-16, rectify=True)
    # memo_optimizer = Adam(memo.parameters(), lr=memo_lr, betas=(0.9, 0.98), eps=1e-9)

    start_time = time.time()

    # Prepare data
    mem = MemoryBatch(memories, step=step_size)

    transition_states, pure_states, transition_actions, expert_ids = mem.collate()
    total_l_old, recon_l_old, context_l_old = 0, 0, 0

    # Main Loop
    kl_beta_schedule = frange_cycle_sigmoid(epochs, **annealing_kwargs)

    for epoch in range(epochs):
        memo.train()

        # Select state transitions and actions at random indexes
        batch_indexes = torch.randint(len(transition_states), (train_batch_size,))

        raw_states_batch, delta_states_batch, actions_batch, sampled_experts = \
           pure_states[batch_indexes], transition_states[batch_indexes], transition_actions[batch_indexes], expert_ids[batch_indexes]


        for i in range(local_iter_per_epoch):
            # kl_beta = kl_beta_schedule[epoch]  # annealing disabled; a fixed weight is used
            kl_beta = 1
            # only take context labeling into account for the first label
            loss, recon_loss, X, latent_labels, vq_loss = memo(raw_states_batch, delta_states_batch,  actions_batch,
                                                                     kl_beta)
            memo_optimizer.zero_grad()
            loss.mean().backward()
            mpi_avg_grads(memo)
            memo_optimizer.step()

        # scheduler.step(loss.mean().data.item())

        total_l_new = loss.mean().data.item()
        recon_l_new = recon_loss.mean().data.item()
        vq_l_new = vq_loss.mean().data.item()

        memo_metrics = {'MEMO Loss': total_l_new, 'Recon Loss': recon_l_new, "VQ Labeling Loss": vq_l_new,
                        "KL Beta": kl_beta_schedule[epoch]}
        wandb.log(memo_metrics)

        logger.store(TotalLoss=total_l_new, PolicyLoss=recon_l_new, # ContextLoss=context_l_new,
                     DeltaTotalLoss=total_l_new-total_l_old, DeltaPolicyLoss=recon_l_new-recon_l_old,
                     )

        total_l_old, recon_l_old = total_l_new, recon_l_new  # , context_l_new

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [memo], None)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpochBatchSize', train_batch_size)
        logger.log_tabular('TotalLoss', average_only=True)
        logger.log_tabular('PolicyLoss', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()

    print("Finished training, and detected %d contexts!" % len(memo.found_contexts))
    # wandb.finish()
    print('memo type', memo)
    return memo, mem
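
memo_valor builds its KL annealing schedule with frange_cycle_sigmoid(epochs, **annealing_kwargs). A hedged sketch of such a cyclical sigmoid schedule, in the spirit of Fu et al. (2019); the signature and defaults below are assumptions, not the project's exact code:

import numpy as np

def frange_cycle_sigmoid(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5):
    # Each cycle ramps a value from `start` to `stop` over the first `ratio`
    # fraction of the cycle, squashed through a sigmoid for a smooth warmup;
    # the remainder of the cycle stays at `stop`.
    schedule = np.ones(n_iter) * stop
    period = n_iter / n_cycle
    step = (stop - start) / (period * ratio)
    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and int(i + c * period) < n_iter:
            schedule[int(i + c * period)] = 1.0 / (1.0 + np.exp(-(v * 12.0 - 6.0)))
            v += step
            i += 1
    return schedule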
Example #3
def train_mlp(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    assert args.cae_weight, "No trained CAE weight provided"
    cae = CAE().to(device)
    cae.load_state_dict(torch.load(args.cae_weight))
    cae.eval()

    print('Building train/validation datasets...')
    train_dataset = PathDataSet(S2D_data_path, cae.encoder)
    val_dataset = PathDataSet(S2D_data_path, cae.encoder, is_val=True)

    print('Creating data loaders...')
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False)

    now = datetime.now()
    output_folder = args.output_folder + '/' + now.strftime(
        '%Y-%m-%d_%H-%M-%S')
    check_and_create_dir(output_folder)

    model = MLP(args.input_size, args.output_size).to(device)
    if args.load_weights:
        print("Load weight from {}".format(args.load_weights))
        model.load_state_dict(torch.load(args.load_weights))

    criterion = nn.MSELoss()
    # optimizer = torch.optim.Adagrad(model.parameters())
    optimizer = AdaBelief(model.parameters(),
                          lr=1e-4,
                          eps=1e-10,
                          betas=(0.9, 0.999),
                          weight_decouple=True,
                          rectify=False)

    for epoch in range(args.max_epoch):
        model.train()

        for i, data in enumerate(tqdm(train_loader)):
            # get data
            input_data = data[0].to(device)  # B, 32
            next_config = data[1].to(device)  # B, 2

            # predict
            predict_config = model(input_data)

            # get loss
            loss = criterion(predict_config, next_config)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            neptune.log_metric("batch_loss", loss.item())

        print('\ncalculating validation loss...')

        model.eval()
        with torch.no_grad():
            losses = []
            for i, data in enumerate(tqdm(val_loader)):
                # get data
                input_data = data[0].to(device)  # B, 32
                next_config = data[1].to(device)  # B, 2

                # predict
                predict_config = model(input_data)

                # get loss
                loss = criterion(predict_config, next_config)

                losses.append(loss.item())

            val_loss = np.mean(losses)
            neptune.log_metric("val_loss", val_loss)

        print("validation result, epoch {}: {}".format(epoch, val_loss))
        if epoch % 5 == 0:
            torch.save(model.state_dict(),
                       '{}/epoch_{}.tar'.format(output_folder, epoch))
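
Note how differently AdaBelief is configured across these examples (eps=1e-16 vs 1e-10, rectify on or off, decoupled weight decay). A minimal standalone sketch of the two common presets, assuming the adabelief_pytorch package; the values are illustrative, not a recommendation from any of these codebases:

import torch
from adabelief_pytorch import AdaBelief

model = torch.nn.Linear(32, 2)  # stand-in model

# SGD-like preset: small eps, decoupled weight decay, no rectification
opt_plain = AdaBelief(model.parameters(), lr=1e-4, eps=1e-16,
                      betas=(0.9, 0.999), weight_decouple=True, rectify=False)

# Rectified preset: adds an RAdam-style variance warmup
opt_rect = AdaBelief(model.parameters(), lr=1e-3, eps=1e-16,
                     betas=(0.9, 0.999), weight_decouple=True, rectify=True)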
Example #4
def train(hyp, opt, device, tb_writer=None, wandb=None):
    logger.info(f'Hyperparameters {hyp}')
    log_dir = Path(tb_writer.log_dir) if tb_writer else Path(
        opt.logdir) / 'evolve'  # logging directory
    wdir = log_dir / 'weights'  # weights directory
    os.makedirs(wdir, exist_ok=True)
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = str(log_dir / 'results.txt')
    epochs, batch_size, total_batch_size, weights, rank = \
        opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Save run settings
    with open(log_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(log_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc, names = (1, ['item']) if opt.single_cls else (int(
        data_dict['nc']), data_dict['names'])  # number classes, names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
        len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get('anchors'):
            ckpt['model'].yaml['anchors'] = round(
                hyp['anchors'])  # force autoanchor
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3,
                      nc=nc).to(device)  # create
        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [
        ]  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict,
                                     model.state_dict(),
                                     exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info(
            'Transferred %g/%g items from %s' %
            (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:  # --adam flag repurposed: AdaBelief is used in place of Adam
        optimizer = AdaBelief(model.parameters(),
                              lr=1e-4,
                              eps=1e-16,
                              betas=(0.9, 0.999),
                              weight_decouple=True,
                              rectify=True)
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp['lr0'],
                              momentum=hyp['momentum'],
                              nesterov=True)

    # optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    # optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp[
        'lrf']) + hyp['lrf']  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Logging
    if wandb and wandb.run is None:
        id = ckpt.get('wandb_id') if 'ckpt' in locals() else None
        wandb_run = wandb.init(config=opt,
                               resume="allow",
                               project="YOLOv5",
                               name=os.path.basename(log_dir),
                               id=id)

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (
                weights, epochs)
            shutil.copytree(wdir, wdir.parent /
                            f'weights_backup_epoch{start_epoch - 1}'
                            )  # save previous weights
        if epochs < start_epoch:
            logger.info(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size
                         ]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Exponential moving average
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model,
                    device_ids=[opt.local_rank],
                    output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path,
                                            imgsz,
                                            batch_size,
                                            gs,
                                            opt,
                                            hyp=hyp,
                                            augment=True,
                                            cache=opt.cache_images,
                                            rect=opt.rect,
                                            rank=rank,
                                            world_size=opt.world_size,
                                            workers=opt.workers)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (
        mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(test_path,
                                       imgsz_test,
                                       total_batch_size,
                                       gs,
                                       opt,
                                       hyp=hyp,
                                       augment=False,
                                       cache=opt.cache_images
                                       and not opt.notest,
                                       rect=True,
                                       rank=-1,
                                       world_size=opt.world_size,
                                       workers=opt.workers)[0]  # testloader

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            plot_labels(labels, save_dir=log_dir)
            if tb_writer:
                # tb_writer.add_hparams(hyp, {})  # causes duplicate https://github.com/ultralytics/yolov5/pull/384
                tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset,
                              model=model,
                              thr=hyp['anchor_t'],
                              imgsz=imgsz)

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(
        device)  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb),
             1e3)  # number of warmup iterations, max(warmup_epochs * nb, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0
               )  # P, R, mAP@0.5, mAP@0.5:0.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    logger.info('Image sizes %g train, %g test\n'
                'Using %g dataloader workers\nLogging results to %s\n'
                'Starting training for %g epochs...' %
                (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs))
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (
                    1 - maps)**2  # class weights
                iw = labels_to_image_weights(dataset.labels,
                                             nc=nc,
                                             class_weights=cw)  # image weights
                dataset.indices = random.choices(
                    range(dataset.n), weights=iw,
                    k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices)
                           if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(
            ('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls',
                                   'total', 'targets', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float(
            ) / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [
                        hyp['warmup_bias_lr'] if j == 2 else 0.0,
                        x['initial_lr'] * lf(epoch)
                    ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(
                            ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(
                    pred, targets.to(device),
                    model)  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9
                                 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 +
                     '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem,
                                      *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if ni < 3:
                    f = str(log_dir / f'train_batch{ni}.jpg')  # filename
                    result = plot_images(images=imgs,
                                         targets=targets,
                                         paths=paths,
                                         fname=f)
                    # if tb_writer and result is not None:
                    # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    # tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(
                    model,
                    include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(
                    opt.data,
                    batch_size=total_batch_size,
                    imgsz=imgsz_test,
                    model=ema.ema,
                    single_cls=opt.single_cls,
                    dataloader=testloader,
                    save_dir=log_dir,
                    plots=epoch == 0 or final_epoch,  # plot first and last
                    log_imgs=opt.log_imgs)

            # Write
            with open(results_file, 'a') as f:
                f.write(
                    s + '%10.4g' * 7 % results +
                    '\n')  # P, R, mAP@0.5, mAP@0.5:0.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                          (results_file, opt.bucket, opt.name))

            # Log
            tags = [
                'train/giou_loss',
                'train/obj_loss',
                'train/cls_loss',  # train loss
                'metrics/precision',
                'metrics/recall',
                'metrics/mAP_0.5',
                'metrics/mAP_0.5:0.95',
                'val/giou_loss',
                'val/obj_loss',
                'val/cls_loss',  # val loss
                'x/lr0',
                'x/lr1',
                'x/lr2'
            ]  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # weighted combination of [P, R, mAP@0.5, mAP@0.5:0.95]
            if fi > best_fitness:
                best_fitness = fi

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {
                        'epoch':
                        epoch,
                        'best_fitness':
                        best_fitness,
                        'training_results':
                        f.read(),
                        'model':
                        ema.ema,
                        'optimizer':
                        None if final_epoch else optimizer.state_dict(),
                        'wandb_id':
                        wandb_run.id if wandb else None
                    }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        n = opt.name if opt.name.isnumeric() else ''
        fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
        for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file],
                          [flast, fbest, fresults]):
            if os.path.exists(f1):
                os.rename(f1, f2)  # rename
                if str(f2).endswith('.pt'):  # is *.pt
                    strip_optimizer(f2)  # strip optimizer
                    os.system(
                        'gsutil cp %s gs://%s/weights' %
                        (f2, opt.bucket)) if opt.bucket else None  # upload
        # Finish
        if not opt.evolve:
            plot_results(save_dir=log_dir)  # save as results.png
        logger.info('%g epochs completed in %.3f hours.\n' %
                    (epoch - start_epoch + 1, (time.time() - t0) / 3600))

    dist.destroy_process_group() if rank not in [-1, 0] else None
    torch.cuda.empty_cache()
    return results
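
The LambdaLR lambda above implements a one-cycle cosine decay: the multiplier starts at 1.0 and eases down to hyp['lrf'] by the final epoch. A standalone sketch showing the shape (epochs=300 and lrf=0.2 are illustrative assumptions):

import math
import torch
from torch.optim import lr_scheduler

epochs, lrf = 300, 0.2
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01)
sched = lr_scheduler.LambdaLR(opt, lr_lambda=lf)

for epoch in range(epochs):
    opt.step()    # lr used this epoch = 0.01 * lf(epoch)
    sched.step()  # lf(0) = 1.0, lf(epochs) = lrf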
Example #5
def main():
    """Model training."""
    train_speakers, valid_speakers = get_valid_speakers()

    # define transforms for train & validation samples
    train_transform = Compose([Resize(760, 80), ToTensor()])

    # define datasets & loaders
    train_dataset = TrainDataset('train',
                                 train_speakers,
                                 transform=train_transform)
    valid_dataset = TrainDataset('train',
                                 valid_speakers,
                                 transform=train_transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False)

    device = get_device()
    print(f'Selected device: {device}')

    model = torch.hub.load('huawei-noah/ghostnet',
                           'ghostnet_1x',
                           pretrained=True)
    model.classifier = nn.Linear(in_features=1280, out_features=1, bias=True)

    net = model
    net.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = AdaBelief(net.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.2,
                                                     patience=3,
                                                     eps=1e-4,
                                                     verbose=True)

    # prepare valid target
    yvalid = get_valid_targets(valid_dataset)

    # training loop
    loss_log = {'train': [], 'valid': []}  # accumulate losses across epochs
    for epoch in range(10):
        train_loss = []

        net.train()
        for x, y in tqdm(train_loader):
            x, y = mixup(x, y, alpha=0.2)
            x, y = x.to(device), y.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = net(x)

            loss = criterion(outputs, y.unsqueeze(1))
            loss.backward()
            optimizer.step()

            # save loss
            train_loss.append(loss.item())

        # evaluate
        net.eval()
        valid_pred = torch.Tensor([]).to(device)

        for x, y in valid_loader:
            with torch.no_grad():
                x, y = x.to(device), y.to(device, dtype=torch.float32)
                ypred = net(x)
                valid_pred = torch.cat([valid_pred, ypred], 0)

        valid_pred = sigmoid(valid_pred.cpu().numpy())
        val_loss = log_loss(yvalid, valid_pred, eps=1e-7)
        val_acc = (yvalid == (valid_pred > 0.5).astype(int).flatten()).mean()
        tqdm.write(
            f'Epoch {epoch} train_loss={np.mean(train_loss):.4f}; val_loss={val_loss:.4f}; val_acc={val_acc:.4f}'
        )

        loss_log['train'].append(np.mean(train_loss))
        loss_log['valid'].append(val_loss)
        scheduler.step(loss_log['valid'][-1])

    torch.save(net.state_dict(), 'ghostnet_model.pt')
    print('Training is complete.')
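
The training loop above calls mixup(x, y, alpha=0.2) before the forward pass and trains against the returned targets with BCEWithLogitsLoss. A hedged sketch of such a helper (Zhang et al., 2018); this is an assumed implementation, not the project's own:

import numpy as np
import torch

def mixup(x, y, alpha=0.2):
    # Convex-combine random pairs of samples and their binary labels.
    lam = np.random.beta(alpha, alpha)   # mixing coefficient in (0, 1)
    index = torch.randperm(x.size(0))    # random partner for each sample
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y.float() + (1 - lam) * y[index].float()
    return mixed_x, mixed_y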