Example #1
    def __call__(self, model: nn.Module, epoch_idx, output_dir, eval_rtn: dict,
                 test_rtn: dict, logger: logging.Logger,
                 writer: SummaryWriter):
        # save model
        acc = eval_rtn.get('err_spk', 0) - eval_rtn.get('err_sph', 1)
        is_best = acc > self.best_accu
        self.best_accu = acc if is_best else self.best_accu
        model_filename = "epoch_{}.pth".format(epoch_idx)
        save_checkpoint(model,
                        os.path.join(output_dir, model_filename),
                        meta={'epoch': epoch_idx})
        os.system("ln -sf {} {}".format(
            os.path.abspath(os.path.join(output_dir, model_filename)),
            os.path.join(output_dir, "latest.pth")))
        if is_best:
            os.system("ln -sf {} {}".format(
                os.path.abspath(os.path.join(output_dir, model_filename)),
                os.path.join(output_dir, "best.pth")))

        if logger is not None:
            logger.info("EvalHook: best accu: {:.3f}, is_best: {}".format(
                self.best_accu, is_best))
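
# The __call__ above is shown without its enclosing class. A minimal sketch of
# the surrounding pieces is given below; the class skeleton, the
# save_checkpoint helper, and the use of os.symlink in place of
# os.system("ln -sf ...") are illustrative assumptions, not the original code.
import os
import torch
import torch.nn as nn


class EvalHook:  # hypothetical wrapper around the __call__ shown above
    def __init__(self):
        # best accuracy seen so far; __call__ compares new results against it
        self.best_accu = float('-inf')


def save_checkpoint(model: nn.Module, path: str, meta: dict = None):
    # assumed helper: store the weights plus optional metadata in one file
    torch.save({'state_dict': model.state_dict(), 'meta': meta or {}}, path)


def force_symlink(src: str, dst: str):
    # portable alternative to os.system("ln -sf src dst")
    if os.path.islink(dst) or os.path.exists(dst):
        os.remove(dst)
    os.symlink(os.path.abspath(src), dst)
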
def main():
    args = arguments()

    num_templates = 25  # aka the number of clusters

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([transforms.ToTensor(), normalize])
    train_loader, _ = get_dataloader(args.traindata,
                                     args,
                                     num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # directory where we'll store model weights
    weights_dir = "weights"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    optimizer = optim.SGD(model.learnable_parameters(args.lr),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Set the start epoch if it has not been set
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch - 1)

    # train and evaluate for `epochs`
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        trainer.train(model,
                      loss_fn,
                      optimizer,
                      train_loader,
                      epoch,
                      device=device)

        if (epoch + 1) % args.save_every == 0:
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'batch_size': train_loader.batch_size,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename="checkpoint_{0}.pth".format(epoch + 1),
                save_path=weights_dir)
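
# trainer.save_checkpoint is called above but not defined in this snippet. A
# minimal sketch consistent with how it is invoked (a state dict, a filename,
# and a save_path) could look like this; it is an assumption for illustration,
# not the project's actual helper.
import os
import torch


def save_checkpoint(state: dict, filename: str, save_path: str = "weights"):
    # write the checkpoint dict built in the training loop above to disk
    os.makedirs(save_path, exist_ok=True)
    torch.save(state, os.path.join(save_path, filename))


# Resuming later would mirror the args.resume branch above:
#   checkpoint = torch.load("weights/checkpoint_10.pth")
#   model.load_state_dict(checkpoint['model'])
#   optimizer.load_state_dict(checkpoint['optimizer'])
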
def train(train_loader, model, optimizer, train_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                  str(batch_idx+1) + "/" + str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        if control_vars['curr_epoch_iter'] < control_vars['start_iter_mod']:
            if batch_idx % control_vars['iter_size'] == 0:
                print_verbose("\rGoing through iterations to arrive at last one saved... " +
                      str(int(control_vars['curr_epoch_iter']*100.0/control_vars['start_iter_mod'])) + "% of " +
                      str(control_vars['start_iter_mod']) + " iterations (" +
                      str(control_vars['curr_epoch_iter']) + "/" + str(control_vars['start_iter_mod']) + ")",
                              verbose, n_tabs=0, erase_line=True)
                control_vars['curr_epoch_iter'] += 1
                control_vars['curr_iter'] += 1
                curr_epoch_iter += 1
            continue
        # save checkpoint after final iteration
        if control_vars['curr_iter'] == control_vars['num_iter']:
            print_verbose("\nReached final number of iterations: " + str(control_vars['num_iter']), verbose)
            print_verbose("\tSaving final model checkpoint...", verbose)
            final_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'control_vars': control_vars,
                'train_vars': train_vars,
            }
            trainer.save_checkpoint(final_model_dict,
                            filename=train_vars['checkpoint_filenamebase'] +
                                     'final' + str(control_vars['num_iter']) + '.pth.tar')
            control_vars['done_training'] = True
            break
        # start time counter
        start = time.time()
        # get data and target as CUDA variables
        target_heatmaps, target_joints, _, target_prior = target
        data, target_heatmaps, target_prior = Variable(data), Variable(target_heatmaps), Variable(target_prior)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_prior = target_prior.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss, loss_prior = my_losses.calculate_loss_HALNet_prior(loss_func,
            output, target_heatmaps, target_prior, model.joint_ixs, model.WEIGHT_LOSS_INTERMED1,
            model.WEIGHT_LOSS_INTERMED2, model.WEIGHT_LOSS_INTERMED3,
            model.WEIGHT_LOSS_MAIN, control_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss
        train_vars['total_loss_prior'] += loss_prior
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'].data[0])
            # erase total loss
            total_loss = train_vars['total_loss'].data[0]
            train_vars['total_loss'] = 0
            # append total loss prior
            train_vars['losses_prior'].append(train_vars['total_loss_prior'].data[0])
            # erase total prior loss
            total_loss_prior = train_vars['total_loss_prior'].data[0]
            train_vars['total_loss_prior'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose("  This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
            if train_vars['losses_prior'][-1] < train_vars['best_loss_prior']:
                train_vars['best_loss_prior'] = train_vars['losses_prior'][-1]
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, control_vars)
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose("Current loss (prior): " + str(total_loss_prior), verbose) + "\n"
                msg += print_verbose("Best loss (prior): " + str(train_vars['best_loss_prior']), verbose) + "\n"
                msg += print_verbose("Mean total loss (prior): " + str(np.mean(train_vars['losses_prior'])), verbose) + "\n"
                msg += print_verbose("Mean loss (prior) for last " + str(control_vars['log_interval']) +
                                     " iterations (average total loss): " + str(
                    np.mean(train_vars['losses_prior'][-control_vars['log_interval']:])), verbose) + "\n"
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')

            if control_vars['curr_iter'] % control_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] + 'for_valid_' +
                                                 str(control_vars['curr_iter']) + '.pth.tar')

            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)

            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1


    return train_vars, control_vars
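
# Both this function and the validation loops below rely on a print_verbose
# helper that is not shown. A sketch consistent with its usage (it prints
# conditionally and returns the message so callers can concatenate it into a
# log string) might be the following; the exact behaviour is an assumption.
import sys


def print_verbose(msg, verbose, n_tabs=0, erase_line=False):
    text = ('\t' * n_tabs) + msg
    if verbose:
        if erase_line:
            # overwrite the current terminal line (used for progress updates)
            sys.stdout.write('\r\x1b[2K' + text)
            sys.stdout.flush()
        else:
            print(text)
    return text
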
def main():
    args = arguments()
    segmentation = False

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    train_loader = get_dataloader(args.traindata, args, device=device)

    model = CoattentionNet()
    loss_fn = SiameseCriterion(device=device)

    pretrained_dict = torch.load(
        "../crowd-counting-revise/weight/checkpoint_104.pth")["model"]
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if "frontend" not in k
        and "backend2" not in k and "main_classifier" not in k
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # directory where we'll store model weights
    weights_dir = "weight_all"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    optimizer = optim.Adam(model.learnable_parameters(args.lr),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    #optimizer = optim.Adam(model.learnable_parameters(args.lr), lr=args.lr)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model = model.to(device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Set the start epoch if it has not been set
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch - 1)
    print("Start training!")

    # train and evaluate for `epochs`
    best_mae = sys.maxsize
    for epoch in range(args.start_epoch, args.epochs):
        if epoch % 4 == 0 and epoch != 0:
            val_loader = val_dataloader(args, device=device)
            with torch.no_grad():
                if not segmentation:
                    mae, mse = evaluate_model(model,
                                              val_loader,
                                              device=device,
                                              training=True,
                                              debug=args.debug,
                                              segmentation=segmentation)
                    if mae < best_mae:
                        best_mae = mae
                        best_mse = mse
                        best_model = "checkpoint_{0}.pth".format(epoch)
                    log_text = 'epoch: %4d, mae: %4.2f, mse: %4.2f' % (
                        epoch - 1, mae, mse)
                    log_print(log_text, color='green', attrs=['bold'])
                    log_text = 'BEST MAE: %0.1f, BEST MSE: %0.1f, BEST MODEL: %s' % (
                        best_mae, best_mse, best_model)
                    log_print(log_text, color='green', attrs=['bold'])
                else:
                    _, _ = evaluate_model(model,
                                          val_loader,
                                          device=device,
                                          training=True,
                                          debug=args.debug,
                                          segmentation=segmentation)
        scheduler.step()
        trainer.train(model,
                      loss_fn,
                      optimizer,
                      train_loader,
                      epoch,
                      device=device)

        if (epoch + 1) % args.save_every == 0:
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'batch_size': train_loader.batch_size,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename="checkpoint_{0}.pth".format(epoch + 1),
                save_path=weights_dir)
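
# evaluate_model is not defined in this snippet. A rough sketch of how MAE and
# MSE are typically computed over a validation loader in crowd counting is
# given below; the signature and the density-map assumption are illustrative
# only and do not reproduce the original implementation.
import math
import torch


def evaluate_model(model, val_loader, device='cpu', training=False,
                   debug=False, segmentation=False):
    model.eval()
    abs_err, sq_err, n = 0.0, 0.0, 0
    with torch.no_grad():
        for img, gt_density in val_loader:
            pred = model(img.to(device))
            # compare predicted and ground-truth person counts per image
            diff = pred.sum().item() - gt_density.sum().item()
            abs_err += abs(diff)
            sq_err += diff ** 2
            n += 1
    return abs_err / n, math.sqrt(sq_err / n)
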
def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx + 1) + "/" + str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as CUDA variables
        target_heatmaps, target_joints, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if valid_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(loss_func,
            output, target_heatmaps, model.joint_ixs, model.WEIGHT_LOSS_INTERMED1,
            model.WEIGHT_LOSS_INTERMED2, model.WEIGHT_LOSS_INTERMED3,
            model.WEIGHT_LOSS_MAIN, control_vars['iter_size'])

        if DEBUG_VISUALLY:
            for i in range(control_vars['max_mem_batch']):
                filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
                filenamebase = valid_loader.dataset.get_filenamebase(filenamebase_idx)
                fig = visualize.create_fig()
                #visualize.plot_joints_from_heatmaps(output[3][i].data.numpy(), fig=fig,
                #                                    title=filenamebase, data=data[i].data.numpy())
                #visualize.plot_image_and_heatmap(output[3][i][8].data.numpy(),
                #                                 data=data[i].data.numpy(),
                #                                 title=filenamebase)
                #visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_heatmap')

                labels_colorspace = conv.heatmaps_to_joints_colorspace(output[3][i].data.numpy())
                data_crop, crop_coords, labels_heatmaps, labels_colorspace = \
                    converter.crop_image_get_labels(data[i].data.numpy(), labels_colorspace, range(21))
                visualize.plot_image(data_crop, title=filenamebase, fig=fig)
                visualize.plot_joints_from_colorspace(labels_colorspace, title=filenamebase, fig=fig, data=data_crop)
                #visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_crop')
                visualize.show()

        #loss.backward()
        valid_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if valid_vars['cross_entropy']:
            valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            valid_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].item())
            # erase total loss
            total_loss = valid_vars['total_loss'].item()
            valid_vars['total_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if valid_vars['losses'][-1] < valid_vars['best_loss']:
                valid_vars['best_loss'] = valid_vars['losses'][-1]
                #print_verbose("  This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(model_dict,
                                        filename=valid_vars['checkpoint_filenamebase'] +
                                                 str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)

            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1


    return valid_vars, control_vars
Example #6
def train():
    parser = argparse.ArgumentParser(
        description='Parameters for training Model')
    # configuration file
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'],
                            opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])
    day_time = datetime.date.today().strftime('%y%m%d')

    # build model
    model = opt['model']['MODEL']
    logger.info("Building the model of {}".format(model))
    # Extraction and Suppression model
    if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction' or opt['model'][
            'MODEL'] == 'DPRNN_Speaker_Suppression':
        net = model_function.Extractin_Suppression_Model(
            **opt['Dual_Path_Aux_Speaker'])
    # Separation model
    if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
        net = model_function.Speech_Serapation_Model(
            **opt['Dual_Path_Aux_Speaker'])
    if opt['train']['gpuid']:
        if len(opt['train']['gpuid']) > 1:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))
        else:
            logger.info('We use GPU : {}'.format(opt['train']['gpuid']))

        device = torch.device('cuda:{}'.format(opt['train']['gpuid'][0]))
        gpuids = opt['train']['gpuid']
        if len(gpuids) > 1:
            net = torch.nn.DataParallel(net, device_ids=gpuids)
        net = net.to(device)
    logger.info('Loading {} parameters: {:.3f} Mb'.format(
        model, check_parameters(net)))

    # build optimizer
    logger.info("Building the optimizer of {}".format(model))
    Optimizer = make_optimizer(net.parameters(), opt)

    Scheduler = ReduceLROnPlateau(Optimizer,
                                  mode='min',
                                  factor=opt['scheduler']['factor'],
                                  patience=opt['scheduler']['patience'],
                                  verbose=True,
                                  min_lr=opt['scheduler']['min_lr'])

    # build dataloader
    logger.info('Building the dataloader of {}'.format(model))
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # build trainer
    logger.info('............. Training ................')

    total_epoch = opt['train']['epoch']
    num_spks = opt['num_spks']
    print_freq = opt['logger']['print_freq']
    checkpoint_path = opt['train']['path']
    early_stop = opt['train']['early_stop']
    max_norm = opt['optim']['clip_norm']
    best_loss = np.inf
    no_improve = 0
    ce_loss = torch.nn.CrossEntropyLoss()
    weight = 0.1

    epoch = 0
    # Resume training settings
    if opt['resume']['state']:
        opt['resume']['path'] = opt['resume'][
            'path'] + '/' + '200722_epoch{}.pth.tar'.format(
                opt['resume']['epoch'])
        ckp = torch.load(opt['resume']['path'], map_location='cpu')
        epoch = ckp['epoch']
        logger.info("Resume from checkpoint {}: epoch {:.3f}".format(
            opt['resume']['path'], epoch))
        net.load_state_dict(ckp['model_state_dict'])
        net.to(device)
        Optimizer.load_state_dict(ckp['optim_state_dict'])

    while epoch < total_epoch:

        epoch += 1
        logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            epoch, 0))
        num_steps = len(train_dataloader)

        # training process
        total_SNRloss = 0.0
        total_CEloss = 0.0
        num_index = 1
        start_time = time.time()
        for inputs, targets in train_dataloader:
            # Separation train
            if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                mix = inputs
                ref = targets
                net.train()

                mix = mix.to(device)
                ref = [ref[i].to(device) for i in range(num_spks)]

                net.zero_grad()
                train_out = net(mix)
                SNR_loss = Loss(train_out, ref)
                loss = SNR_loss

            # Extraction train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                mix, aux = inputs
                ref, aux_len, sp_label = targets
                net.train()

                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                sp_label = sp_label.to(device)

                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                CE_loss = torch.mean(ce_loss(train_out[1], sp_label))
                loss = SNR_loss + weight * CE_loss
                total_CEloss += CE_loss.item()

            # Suppression train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                mix, aux = inputs
                ref, aux_len = targets
                net.train()

                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)

                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                loss = SNR_loss

            # backpropagation process
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)
            Optimizer.step()

            total_SNRloss += SNR_loss.item()

            if num_index % print_freq == 0:
                message = '<Training epoch:{:d} / {:d} , iter:{:d} / {:d}, lr:{:.3e}, SI-SNR_loss:{:.3f}, CE loss:{:.3f}>'.format(
                    epoch, total_epoch, num_index, num_steps,
                    Optimizer.param_groups[0]['lr'], total_SNRloss / num_index,
                    total_CEloss / num_index)
                logger.info(message)

            num_index += 1

        end_time = time.time()
        mean_SNRLoss = total_SNRloss / num_index
        mean_CELoss = total_CEloss / num_index

        message = 'Finished Training *** <epoch:{:d} / {:d}, iter:{:d}, lr:{:.3e}, ' \
                  'SNR loss:{:.3f}, CE loss:{:.3f}, Total time:{:.3f} min> '.format(
            epoch, total_epoch, num_index, Optimizer.param_groups[0]['lr'], mean_SNRLoss, mean_CELoss, (end_time - start_time) / 60)
        logger.info(message)

        # development process
        val_num_index = 1
        val_total_loss = 0.0
        val_CE_loss = 0.0
        val_acc_total = 0.0
        val_acc = 0.0
        val_start_time = time.time()
        val_num_steps = len(val_dataloader)
        for inputs, targets in val_dataloader:
            net.eval()
            with torch.no_grad():
                # Separation development
                if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                    mix = inputs
                    ref = targets
                    mix = mix.to(device)
                    ref = [ref[i].to(device) for i in range(num_spks)]
                    Optimizer.zero_grad()
                    val_out = net(mix)
                    val_loss = Loss(val_out, ref)
                    val_total_loss += val_loss.item()

                # Extraction development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                    mix, aux = inputs
                    ref, aux_len, label = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    label = label.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_ce = torch.mean(ce_loss(val_out[1], label))
                    val_acc = accuracy_speaker(val_out[1], label)
                    val_acc_total += val_acc
                    val_total_loss += val_loss.item()
                    val_CE_loss += val_ce.item()

                # suppression development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                    mix, aux = inputs
                    ref, aux_len = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_total_loss += val_loss.item()

                if val_num_index % print_freq == 0:
                    message = '<Valid-Epoch:{:d} / {:d}, iter:{:d} / {:d}, lr:{:.3e}, ' \
                              'val_SISNR_loss:{:.3f}, val_CE_loss:{:.3f}, val_acc :{:.3f}>' .format(
                        epoch, total_epoch, val_num_index, val_num_steps, Optimizer.param_groups[0]['lr'],
                        val_total_loss / val_num_index,
                        val_CE_loss / val_num_index,
                        val_acc_total / val_num_index)
                    logger.info(message)
            val_num_index += 1

        val_end_time = time.time()
        mean_val_total_loss = val_total_loss / val_num_index
        mean_val_CE_loss = val_CE_loss / val_num_index
        mean_acc = val_acc_total / val_num_index
        message = 'Finished *** <epoch:{:d}, iter:{:d}, lr:{:.3e}, val SI-SNR loss:{:.3f}, val_CE_loss:{:.3f}, val_acc:{:.3f}' \
                  ' Total time:{:.3f} min> '.format(epoch, val_num_index, Optimizer.param_groups[0]['lr'],
                                                    mean_val_total_loss, mean_val_CE_loss, mean_acc,
                                                    (val_end_time - val_start_time) / 60)
        logger.info(message)

        Scheduler.step(mean_val_total_loss)

        if mean_val_total_loss >= best_loss:
            no_improve += 1
            logger.info(
                'No improvement, Best SI-SNR Loss: {:.4f}'.format(best_loss))

        if mean_val_total_loss < best_loss:
            best_loss = mean_val_total_loss
            no_improve = 0
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info(
                'Epoch: {:d}, Now Best SI-SNR Loss Change: {:.4f}'.format(
                    epoch, best_loss))

        if no_improve == early_stop:
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info("Stop training cause no impr for {:d} epochs".format(
                no_improve))
            break
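
# save_checkpoint is used above but not defined here. A sketch that matches
# the keys read back in the resume branch (epoch, model_state_dict,
# optim_state_dict) and the '<day_time>_epoch<N>.pth.tar' naming is shown
# below; treat it as an assumption, not the original helper.
import os
import torch


def save_checkpoint(epoch, checkpoint_path, net, optimizer, day_time):
    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
        },
        os.path.join(checkpoint_path,
                     '{}_epoch{}.pth.tar'.format(day_time, epoch)))
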
Example #7
    # Hyperparams that have been found to work well
    tconf = trainer.TrainerConfig(max_epochs=650,
                                  batch_size=128,
                                  learning_rate=6e-3,
                                  lr_decay=True,
                                  warmup_tokens=512 * 20,
                                  final_tokens=200 * len(pretrain_dataset) *
                                  block_size,
                                  num_workers=4,
                                  ckpt_path=args.writing_params_path)

    # Initiate trainer, train, then save params of model
    trainer = trainer.Trainer(model, pretrain_dataset, None, tconf)
    trainer.train()
    trainer.save_checkpoint()

elif args.function == 'finetune':
    assert args.writing_params_path is not None
    assert args.finetune_corpus_path is not None
    # - Given:
    #     1. A finetuning corpus specified in args.finetune_corpus_path
    #     2. A path args.reading_params_path containing pretrained model
    #         parameters, or None if finetuning without a pretrained model
    #     3. An output path args.writing_params_path for the model parameters
    # - Goals:
    #     1. If args.reading_params_path is specified, load these parameters
    #         into the model
    #     2. Finetune the model on this corpus
    #     3. Save the resulting model in args.writing_params_path
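    # A sketch of how this branch might be filled in, reusing the same
    # trainer/TrainerConfig API as the pretraining branch above. The
    # NameDataset helper and the hyperparameter values are assumptions for
    # illustration, not the assignment's reference solution.
    if args.reading_params_path is not None:
        model.load_state_dict(torch.load(args.reading_params_path))
    finetune_corpus = open(args.finetune_corpus_path, encoding='utf-8').read()
    finetune_dataset = dataset.NameDataset(pretrain_dataset, finetune_corpus)
    tconf = trainer.TrainerConfig(max_epochs=75,
                                  batch_size=256,
                                  learning_rate=6e-4,
                                  lr_decay=True,
                                  warmup_tokens=512 * 20,
                                  final_tokens=200 * len(pretrain_dataset) *
                                  block_size,
                                  num_workers=4,
                                  ckpt_path=args.writing_params_path)
    model_trainer = trainer.Trainer(model, finetune_dataset, None, tconf)
    model_trainer.train()
    model_trainer.save_checkpoint()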
def train(train_loader, model, optimizer, train_vars):
    verbose = train_vars['verbose']
    for batch_idx, (data, target) in enumerate(train_loader):
        train_vars['batch_idx'] = batch_idx
        # print info about performing first iter
        if batch_idx < train_vars['iter_size']:
            print_verbose(
                "\rPerforming first iteration; current mini-batch: " +
                str(batch_idx + 1) + "/" + str(train_vars['iter_size']),
                verbose,
                n_tabs=0,
                erase_line=True)
        # check if arrived at iter to start
        arrived_curr_iter, train_vars = run_until_curr_iter(
            batch_idx, train_vars)
        if not arrived_curr_iter:
            continue
        # save checkpoint after final iteration
        if train_vars['curr_iter'] - 1 == train_vars['num_iter']:
            train_vars = trainer.save_final_checkpoint(train_vars, model,
                                                       optimizer)
            break
        # start time counter
        start = time.time()
        # get data and target as torch Variables
        _, target_joints, target_heatmaps, target_joints_z = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_joints = target_joints.cuda()
            target_joints_z = target_joints_z.cuda()
        # get model output
        output = model(data)

        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(
            train_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints,
            train_vars['joint_ixs'], weights_heatmaps_loss,
            weights_joints_loss, train_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss.item()
        train_vars['total_joints_loss'] += loss_joints.item()
        train_vars['total_heatmaps_loss'] += loss_heatmaps.item()
        # accumulate pixel dist loss for sub-mini-batch
        train_vars[
            'total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
                train_vars['total_pixel_loss'], output[3], target_heatmaps,
                train_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars[
                'total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                    train_vars['total_pixel_loss_sample'], output[3],
                    target_heatmaps, train_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        '''
        For debugging training
        for i in range(train_vars['max_mem_batch']):
            filenamebase_idx = (batch_idx * train_vars['max_mem_batch']) + i
            filenamebase = train_loader.dataset.get_filenamebase(filenamebase_idx)
            visualize.plot_joints_from_heatmaps(target_heatmaps[i].data.cpu().numpy(),
                                                title='GT joints: ' + filenamebase, data=data[i].data.cpu().numpy())
            visualize.plot_joints_from_heatmaps(output[3][i].data.cpu().numpy(),
                                                title='Pred joints: ' + filenamebase, data=data[i].data.cpu().numpy())
            visualize.plot_image_and_heatmap(output[3][i][4].data.numpy(),
                                             data=data[i].data.numpy(),
                                             title='Thumb tip heatmap: ' + filenamebase)
            visualize.show()
        '''

        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % train_vars['iter_size'] == 0
        if minibatch_completed:
            # visualize
            # ax, fig = visualize.plot_3D_joints(target_joints[0])
            # visualize.plot_3D_joints(target_joints[1], ax=ax, fig=fig)
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                fig, ax = visualize.plot_3D_joints(target_joints[0])
                visualize.savefig('joints_GT_' + str(train_vars['curr_iter']) +
                                  '.png')
                #visualize.plot_3D_joints(target_joints[1], fig=fig, ax=ax, color_root='C7')
                #visualize.plot_3D_joints(output[7].data.cpu().numpy()[0], fig=fig, ax=ax, color_root='C7')
                visualize.plot_3D_joints(output[7].data.cpu().numpy()[0])
                visualize.savefig('joints_model_' +
                                  str(train_vars['curr_iter']) + '.png')
                #visualize.show()
                #visualize.savefig('joints_' + str(train_vars['curr_iter']) + '.png')
            # change learning rate to 0.01 after 45000 iterations
            optimizer = change_learning_rate(optimizer, 0.01,
                                             train_vars['curr_iter'])
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'])
            # erase total loss
            total_loss = train_vars['total_loss']
            train_vars['total_loss'] = 0
            # append total joints loss
            train_vars['losses_joints'].append(train_vars['total_joints_loss'])
            # erase total joints loss
            train_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            train_vars['losses_heatmaps'].append(
                train_vars['total_heatmaps_loss'])
            # erase total heatmaps loss
            train_vars['total_heatmaps_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(
                train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(
                    "  This is a best loss found so far: " +
                    str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars
                }
            # log checkpoint
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss,
                                       train_vars, train_vars)
                aa1 = target_joints[0].data.cpu().numpy()
                aa2 = output[7][0].data.cpu().numpy()
                output_joint_loss = np.sum(np.abs(aa1 - aa2)) / 63
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose(
                    '\tJoint Coord Avg Loss for first image of current mini-batch: '
                    + str(output_joint_loss) + '\n', train_vars['verbose'])
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not train_vars['output_filepath'] == '':
                    with open(train_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if train_vars['curr_iter'] % train_vars['log_interval_valid'] == 0:
                print_verbose(
                    "\nSaving model and checkpoint model for validation",
                    verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(
                    checkpoint_model_dict,
                    filename=train_vars['checkpoint_filenamebase'] +
                    'for_valid_' + str(train_vars['curr_iter']) + '.pth.tar')

            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +\
                     str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx']+1) +\
                     '(' + str(train_vars['iter_size']) + ')' + '/' +\
                     str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +\
                     '(' + str(train_vars['batch_size']) + ')' +\
                     ' - log every ' + str(train_vars['log_interval']) + ' iter): '
            train_vars['tot_toc'] = display_est_time_loop(
                train_vars['tot_toc'] + time.time() - start,
                train_vars['curr_iter'],
                train_vars['num_iter'],
                prefix=prefix)

            train_vars['curr_iter'] += 1
            train_vars['start_iter'] = train_vars['curr_iter'] + 1
            train_vars['curr_epoch_iter'] += 1
    return train_vars
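
# change_learning_rate and get_loss_weights are referenced above but not
# defined in this snippet. A sketch of the learning-rate helper consistent
# with the "change learning rate to 0.01 after 45000 iterations" comment is
# given below; the exact switch condition is an assumption.
def change_learning_rate(optimizer, new_lr, curr_iter, switch_iter=45000):
    if curr_iter >= switch_iter:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
    return optimizer
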
Example #9
def validate(valid_loader,
             model,
             optimizer,
             valid_vars,
             control_vars,
             verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose(
                "\rPerforming first iteration; current mini-batch: " +
                str(batch_idx + 1) + "/" + str(control_vars['iter_size']),
                verbose,
                n_tabs=0,
                erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as CUDA variables
        target_heatmaps, target_joints, target_handroot = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_joints = target_joints.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_handroot = target_handroot.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(
            control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints,
            valid_vars['joint_ixs'], weights_heatmaps_loss,
            weights_joints_loss, control_vars['iter_size'])

        for i in range(control_vars['max_mem_batch']):
            filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
            filenamebase = valid_loader.dataset.get_filenamebase(
                filenamebase_idx)

            print('')
            print(filenamebase)

            visualize.plot_image(data[i].data.numpy())
            visualize.show()

            output_batch_numpy = output[7][i].data.cpu().numpy()
            print('\n-------------------------------')
            reshaped_out = output_batch_numpy.reshape((20, 3))
            for j in range(20):
                print('[{}, {}, {}],'.format(reshaped_out[j, 0],
                                             reshaped_out[j, 1],
                                             reshaped_out[j, 2]))
            print('-------------------------------')
            fig, ax = visualize.plot_3D_joints(target_joints[i])
            visualize.plot_3D_joints(output_batch_numpy,
                                     fig=fig,
                                     ax=ax,
                                     color='C6')

            visualize.title(filenamebase)
            visualize.show()

            temp = np.zeros((21, 3))
            output_batch_numpy_abs = output_batch_numpy.reshape((20, 3))
            temp[1:, :] = output_batch_numpy_abs
            output_batch_numpy_abs = temp
            output_joints_colorspace = camera.joints_depth2color(
                output_batch_numpy_abs,
                depth_intr_matrix=synthhands_handler.DEPTH_INTR_MTX,
                handroot=target_handroot[i].data.cpu().numpy())
            visualize.plot_3D_joints(output_joints_colorspace)
            visualize.show()
            aa1 = target_joints[i].data.cpu().numpy().reshape((20, 3))
            aa2 = output[7][i].data.cpu().numpy().reshape((20, 3))
            print('\n----------------------------------')
            print(np.sum(np.abs(aa1 - aa2)) / 60)
            print('----------------------------------')

        #loss.backward()
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars[
            'total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
                valid_vars['total_pixel_loss'], output[3], target_heatmaps,
                control_vars['batch_size'])
        valid_vars[
            'total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                valid_vars['total_pixel_loss_sample'], output[3],
                target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].data[0])
            # erase total loss
            total_loss = valid_vars['total_loss'].data[0]
            valid_vars['total_loss'] = 0
            # append total joints loss
            valid_vars['losses_joints'].append(
                valid_vars['total_joints_loss'].data[0])
            # erase total joints loss
            valid_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            valid_vars['losses_heatmaps'].append(
                valid_vars['total_heatmaps_loss'].data[0])
            # erase total heatmaps loss
            valid_vars['total_heatmaps_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(
                valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            #if valid_vars['losses'][-1] < valid_vars['best_loss']:
            #    valid_vars['best_loss'] = valid_vars['losses'][-1]
            #    print_verbose("  This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss,
                                       valid_vars, control_vars)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(
                    model_dict,
                    filename=valid_vars['checkpoint_filenamebase'] +
                    str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(
                control_vars['tot_toc'] + time.time() - start,
                control_vars['curr_iter'],
                control_vars['num_iter'],
                prefix=prefix)

            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1

    return valid_vars, control_vars
Example #10
def train(train_loader, model, optimizer, train_vars):
    verbose = train_vars['verbose']
    for batch_idx, (data, target) in enumerate(train_loader):
        train_vars['batch_idx'] = batch_idx
        # print info about performing first iter
        if batch_idx < train_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                  str(batch_idx+1) + "/" + str(train_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        arrived_curr_iter, train_vars = run_until_curr_iter(batch_idx, train_vars)
        if not arrived_curr_iter:
            continue
        # save checkpoint after final iteration
        if train_vars['curr_iter'] - 1 == train_vars['num_iter']:
            train_vars = save_final_checkpoint(train_vars, model, optimizer)
            break
        # start time counter
        start = time.time()
        # get data and target as torch Variables
        _, target_joints, target_heatmaps, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(loss_func,
            output, target_heatmaps, model.joint_ixs, model.WEIGHT_LOSS_INTERMED1,
            model.WEIGHT_LOSS_INTERMED2, model.WEIGHT_LOSS_INTERMED3,
            model.WEIGHT_LOSS_MAIN, train_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, train_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, train_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % train_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'].item())
            # erase total loss
            total_loss = train_vars['total_loss'].item()
            train_vars['total_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose("  This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars
                }
            # log checkpoint
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, train_vars)

            if train_vars['curr_iter'] % train_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] + 'for_valid_' +
                                                 str(train_vars['curr_iter']) + '.pth.tar')

            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +\
                     str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx']+1) +\
                     '(' + str(train_vars['iter_size']) + ')' + '/' +\
                     str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +\
                     '(' + str(train_vars['batch_size']) + ')' +\
                     ' - log every ' + str(train_vars['log_interval']) + ' iter): '
            train_vars['tot_toc'] = display_est_time_loop(train_vars['tot_toc'] + time.time() - start,
                                                            train_vars['curr_iter'], train_vars['num_iter'],
                                                            prefix=prefix)

            train_vars['curr_iter'] += 1
            train_vars['start_iter'] = train_vars['curr_iter'] + 1
            train_vars['curr_epoch_iter'] += 1
    return train_vars
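As in the other training loops on this page, gradients are accumulated over iter_size sub-mini-batches (the loss helper receives iter_size, presumably to scale each sub-batch loss), and optimizer.step() runs only once per accumulated mini-batch. Below is a minimal, self-contained sketch of that pattern with generic stand-ins for the model, loader, and loss; none of these names come from the original code.

def train_with_grad_accumulation(model, loader, loss_fn, optimizer, iter_size, device='cpu'):
    # Accumulate gradients over iter_size sub-mini-batches, then take a single optimizer step.
    model.train()
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        loss = loss_fn(model(data), target) / iter_size   # scale so the accumulated gradient matches one large batch
        loss.backward()                                   # gradients add up across sub-batches
        if (batch_idx + 1) % iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()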
Exemple #11
0
def main_func(args):

    cdf = mc.ConfidenceDepthFrameworkFactory()
    val_loader, _ = df.create_data_loaders(args.data_path
                                           , loader_type='val'
                                           , data_type=args.data_type
                                           , modality=args.data_modality
                                           , num_samples=args.num_samples
                                           , depth_divisor=args.divider
                                           , max_depth=args.max_depth
                                           , max_gt_depth=args.max_gt_depth
                                           , workers=args.workers
                                           , batch_size=1)
    if not args.evaluate:
        train_loader, _ = df.create_data_loaders(args.data_path
                                                 , loader_type='train'
                                                 , data_type=args.data_type
                                                 , modality=args.data_modality
                                                 , num_samples=args.num_samples
                                                 , depth_divisor=args.divider
                                                 , max_depth=args.max_depth
                                                 , max_gt_depth=args.max_gt_depth
                                                 , workers=args.workers
                                                 , batch_size=args.batch_size)

    # evaluation mode
    if args.evaluate:
        cdfmodel, loss, epoch = trainer.resume(args.evaluate, cdf, True)
        output_directory = create_eval_output_folder(args)
        os.makedirs(output_directory)
        print(output_directory)
        save_arguments(args,output_directory)
        trainer.validate(val_loader, cdfmodel, loss, epoch, print_frequency=args.print_freq,
                         num_image_samples=args.val_images, output_folder=output_directory,
                         conf_recall=args.pr, conf_threshold=args.thrs)
        return

    output_directory = create_output_folder(args)
    os.makedirs(output_directory)
    print(output_directory)
    save_arguments(args, output_directory)

    # optionally resume from a checkpoint
    if args.resume:
        cdfmodel, loss, loss_definition, best_result_error, optimizer, scheduler = trainer.resume(args.resume, cdf, False)

    # create new model
    else:
        cdfmodel = cdf.create_model(args.dcnet_modality, args.training_mode,
                                    args.dcnet_arch, args.dcnet_pretrained,
                                    args.confnet_arch, args.confnet_pretrained,
                                    args.lossnet_arch, args.lossnet_pretrained)
        cdfmodel, opt_parameters = cdf.to_device(cdfmodel)
        optimizer, scheduler = trainer.create_optimizer(args.optimizer, opt_parameters, args.momentum,
                                                        args.weight_decay, args.lr, args.lrs, args.lrm)
        loss, loss_definition = cdf.create_loss(args.criterion, ('ln' in args.training_mode),
                                                (0.5 if 'dc1' in args.training_mode else 1.0))
        best_result_error = math.inf


    for epoch in range(0, args.epochs):
        trainer.train(train_loader, cdfmodel, loss, optimizer, output_directory, epoch)
        epoch_result = trainer.validate(val_loader, cdfmodel, loss, epoch=epoch, print_frequency=args.print_freq,
                                        num_image_samples=args.val_images, output_folder=output_directory)
        scheduler.step(epoch)

        is_best = epoch_result.rmse < best_result_error
        if is_best:
            best_result_error = epoch_result.rmse
            trainer.report_top_result(os.path.join(output_directory, 'best_result.txt'), epoch, epoch_result)
            # if img_merge is not None:
            #     img_filename = output_directory + '/comparison_best.png'
            #     utils.save_image(img_merge, img_filename)

        trainer.save_checkpoint(cdf, cdfmodel, loss_definition, optimizer, scheduler, best_result_error, is_best,
                                epoch, output_directory)
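main_func keeps the lowest validation RMSE in best_result_error and forwards an is_best flag to trainer.save_checkpoint. Below is a minimal sketch of that save-then-copy-best pattern, assuming a plain torch.save-based helper; it is an illustration, not the project's actual trainer.save_checkpoint.

import os
import shutil
import torch

def save_checkpoint_with_best(state, is_best, output_directory, filename='checkpoint.pth.tar'):
    # Always write the latest checkpoint; duplicate it as best.pth.tar when it improves on the previous best.
    checkpoint_path = os.path.join(output_directory, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(output_directory, 'best.pth.tar'))

Here state would bundle the model, loss definition, optimizer, scheduler, and best_result_error, mirroring the arguments forwarded above.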
def train(train_loader, model, optimizer, train_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                  str(batch_idx+1) + "/" + str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        if control_vars['curr_epoch_iter'] < control_vars['start_iter_mod']:
            control_vars['curr_epoch_iter'] = control_vars['start_iter_mod']
            msg = ''
            if batch_idx % control_vars['iter_size'] == 0:
                msg += print_verbose("\rGoing through iterations to arrive at last one saved... " +
                      str(int(control_vars['curr_epoch_iter']*100.0/control_vars['start_iter_mod'])) + "% of " +
                      str(control_vars['start_iter_mod']) + " iterations (" +
                      str(control_vars['curr_epoch_iter']) + "/" + str(control_vars['start_iter_mod']) + ")",
                              verbose, n_tabs=0, erase_line=True)
                control_vars['curr_epoch_iter'] += 1
                control_vars['curr_iter'] += 1
                curr_epoch_iter += 1
            if not control_vars['output_filepath'] == '':
                with open(control_vars['output_filepath'], 'a') as f:
                    f.write(msg + '\n')
            continue
        # save checkpoint after final iteration
        if control_vars['curr_iter'] == control_vars['num_iter']:
            print_verbose("\nReached final number of iterations: " + str(control_vars['num_iter']), verbose)
            print_verbose("\tSaving final model checkpoint...", verbose)
            final_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'control_vars': control_vars,
                'train_vars': train_vars,
            }
            trainer.save_checkpoint(final_model_dict,
                            filename=train_vars['checkpoint_filenamebase'] +
                                     'final' + str(control_vars['num_iter']) + '.pth.tar')
            control_vars['done_training'] = True
            break
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_roothand = target
        data, target_heatmaps, target_joints, target_roothand = \
            Variable(data), Variable(target_heatmaps), Variable(target_joints), Variable(target_roothand)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_joints = target_joints.cuda()
        # get model output
        output = model(data)
        '''
        visualize.plot_joints_from_heatmaps(target_heatmaps[0, :, :, :].cpu().data.numpy(),
                                            title='', data=data[0].cpu().data.numpy())
        visualize.show()
        visualize.plot_image_and_heatmap(target_heatmaps[0][4].cpu().data.numpy(),
                                         data=data[0].cpu().data.numpy(),
                                         title='')
        visualize.show()
        visualize.plot_image_and_heatmap(output[3][0][4].cpu().data.numpy(),
                                         data=data[0].cpu().data.numpy(),
                                         title='')
        visualize.show()
        '''
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, train_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, control_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss.item()
        train_vars['total_joints_loss'] += loss_joints.item()
        train_vars['total_heatmaps_loss'] += loss_heatmaps.item()
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'])
            # erase total loss
            total_loss = train_vars['total_loss']
            train_vars['total_loss'] = 0
            # append total joints loss
            train_vars['losses_joints'].append(train_vars['total_joints_loss'])
            # erase total joints loss
            train_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            train_vars['losses_heatmaps'].append(train_vars['total_heatmaps_loss'])
            # erase total heatmaps loss
            train_vars['total_heatmaps_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose("  This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
            if train_vars['losses_joints'][-1] < train_vars['best_loss_joints']:
                train_vars['best_loss_joints'] = train_vars['losses_joints'][-1]
            if train_vars['losses_heatmaps'][-1] < train_vars['best_loss_heatmaps']:
                train_vars['best_loss_heatmaps'] = train_vars['losses_heatmaps'][-1]
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, control_vars)
                # average absolute joint-coordinate error for the first image of the mini-batch
                # (63 = total number of joint coordinates, e.g. 21 joints x 3)
                target_joints_np = target_joints[0].data.cpu().numpy()
                output_joints_np = output[7][0].data.cpu().numpy()
                output_joint_loss = np.sum(np.abs(target_joints_np - output_joints_np)) / 63
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose('\tJoint Coord Avg Loss for first image of current mini-batch: ' +
                                     str(output_joint_loss) + '\n', control_vars['verbose'])
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if control_vars['curr_iter'] % control_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] + 'for_valid_' +
                                                 str(control_vars['curr_iter']) + '.pth.tar')

            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)

            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1


    return train_vars, control_vars
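The block near the top of this loop fast-forwards through batches until it reaches the iteration stored in the checkpoint (start_iter_mod). Below is a minimal sketch of that resume-and-skip idea, assuming the checkpoint records how many optimizer iterations were already completed; the helper name is illustrative, and the result only lines up with the original run if the data loader's ordering is deterministic.

def skip_completed_batches(loader, completed_iters, iter_size):
    # Each optimizer iteration consumes iter_size sub-mini-batches, so skip that many batches
    # before yielding anything to the training loop of a resumed run.
    batches_to_skip = completed_iters * iter_size
    for batch_idx, batch in enumerate(loader):
        if batch_idx < batches_to_skip:
            continue
        yield batch_idx, batch

A resumed run would iterate over this generator instead of enumerate(train_loader) directly.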
Exemple #13
0
def main():
    args = arguments.parse()

    checkpoint = args.checkpoint if args.checkpoint else None

    model, params = get_network(args.arch,
                                args.n_attrs,
                                checkpoint=checkpoint,
                                base_frozen=args.freeze_base)

    criterion = get_criterion(loss_type=args.loss, args=args)

    optimizer = get_optimizer(params,
                              fc_lr=float(args.lr),
                              opt_type=args.optimizer_type,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=10,
                                          gamma=0.1,
                                          last_epoch=args.start_epoch - 1)
    if checkpoint:
        state = torch.load(checkpoint)
        model.load_state_dict(state["state_dict"])
        scheduler.load_state_dict(state['scheduler'])

    # Dataloader code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    logger.info("Setting up training data")
    train_loader = data.DataLoader(COCOAttributes(
        args.attributes,
        args.train_ann,
        train=True,
        split='train2014',
        transforms=train_transforms,
        dataset_root=args.dataset_root),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)

    logger.info("Setting up validation data")
    val_loader = data.DataLoader(COCOAttributes(
        args.attributes,
        args.val_ann,
        train=False,
        split='val2014',
        transforms=val_transforms,
        dataset_root=args.dataset_root),
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    best_prec1 = 0

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    logger.info("Beginning training...")

    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()

        # train for one epoch
        trainer.train(train_loader, model, criterion, optimizer, epoch,
                      args.print_freq)

        # evaluate on validation set
        # trainer.validate(val_loader, model, criterion, epoch, args.print_freq)
        prec1 = 0

        # remember best prec@1 and save checkpoint
        best_prec1 = max(prec1, best_prec1)
        trainer.save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'loss': args.loss,
                'optimizer': args.optimizer_type,
                'state_dict': model.state_dict(),
                'scheduler': scheduler.state_dict(),
                'batch_size': args.batch_size,
                'best_prec1': best_prec1,
            }, args.save_dir,
            '{0}_{1}_checkpoint.pth.tar'.format(args.arch, args.loss).lower())

    logger.info('Finished Training')

    logger.info('Running evaluation')
    evaluator = evaluation.Evaluator(model,
                                     val_loader,
                                     batch_size=args.batch_size,
                                     name="{0}_{1}".format(
                                         args.arch, args.loss))
    with torch.no_grad():
        evaluator.evaluate()
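The checkpoint dictionary written in the loop above stores the model and scheduler state together with epoch and best_prec1, so a later run can pick up where it stopped. Below is a minimal sketch of restoring those fields; the helper name and the map_location choice are illustrative.

import torch

def load_training_state(model, scheduler, checkpoint_path):
    # Restore the fields saved by the checkpoint dictionary above.
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    scheduler.load_state_dict(state['scheduler'])
    start_epoch = state['epoch']      # the epoch to resume from (already incremented when saved)
    best_prec1 = state['best_prec1']
    return start_epoch, best_prec1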