Beispiel #1
0
def main_merge():
    """Train the Baseline model on the merged train+val split.

    Reads configuration from the module-level ``args`` namespace and
    mutates ``args.store_name`` / ``args.start_epoch`` in place.
    Optionally resumes from ``args.resume``.  A checkpoint is written
    after every epoch; no validation is run on the merged split, so the
    stored ``best_corr`` is always 0.0.
    """
    global args, best_corr

    # Unique run name: <model>_merged_<MM-DD_HH-MM>
    args.store_name = '{}_merged'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M')
    args.start_epoch = 0

    check_rootfolders(args)

    model = Baseline(args.img_feat_size, args.au_feat_size)
    model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    scheduler = None
    if args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)

    # Merged dataset: train and val CSVs/vidmaps combined.
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=[args.train_csv, args.val_csv],
            vidmap_path=[args.train_vidmap, args.val_vidmap],
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='merge'),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=True)

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    # BUGFIX: the training log and the SummaryWriter were never closed.
    # Keep both open for the whole loop but release them deterministically.
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    tb_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    try:
        for epoch in range(args.start_epoch, args.epochs):
            train(train_loader, model, optimizer, epoch, log_training,
                  tb_writer)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': 0.0,  # no validation on the merged split
                }, False)
            if scheduler is not None:
                scheduler.step()
    finally:
        log_training.close()
        tb_writer.close()
Beispiel #2
0
def main():
    """Evaluate a Baseline lane-segmentation net on the CULane test list.

    Loads ``<checkpoint>_checkpoint.pth``, strips the ``module.`` prefix
    that DataParallel training adds to parameter names, then runs every
    image listed in ``test_gt.txt`` through the network, accumulating
    loss and per-image predictions for the final pixel metrics.  When
    ``args['save_results']`` is set, per-image lane coordinates are also
    written under ``save_dir``.
    """
    net = Baseline(num_classes=culane.num_classes,
                   deep_base=args['deep_base']).cuda()

    print('load checkpoint \'%s.pth\' for evaluation' % args['checkpoint'])
    pretrained_dict = torch.load(
        os.path.join(ckpt_path, exp_name,
                     args['checkpoint'] + '_checkpoint.pth'))
    # Strip the leading 'module.' (7 chars) added by DataParallel.
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items()}
    net.load_state_dict(pretrained_dict)

    net.eval()

    save_dir = os.path.join(ckpt_path, exp_name,
                            'vis_%s_test' % args['checkpoint'])
    check_mkdir(save_dir)
    log_path = os.path.join(save_dir, str(datetime.datetime.now()) + '.log')

    # BUGFIX: the list file handle was opened without ever being closed.
    with open(os.path.join(culane.root, culane.list, 'test_gt.txt'),
              'r') as list_file:
        data_list = [l.strip('\n') for l in list_file]

    loss_record = AverageMeter()
    gt_all, prediction_all = [], []

    for idx in range(len(data_list)):
        print('evaluating %d / %d' % (idx + 1, len(data_list)))

        # Each list line is '<img_path> <gt_path> ...'.
        img = Image.open(
            culane.root + data_list[idx].split(' ')[0]).convert('RGB')
        gt = Image.open(culane.root + data_list[idx].split(' ')[1])

        img, gt = val_joint_transform(img, gt)

        with torch.no_grad():
            img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
            gt_var = Variable(mask_transform(gt).unsqueeze(0)).cuda()

            prediction = net(img_var)[0]

            loss = criterion(prediction, gt_var)
            loss_record.update(loss.data, 1)

            # Class-probability map, used for lane-line extraction below.
            scoremap = F.softmax(prediction,
                                 dim=1).data.squeeze().cpu().numpy()

            prediction = prediction.data.max(
                1)[1].squeeze().cpu().numpy().astype(np.uint8)
            prediction_all.append(prediction)
            gt_all.append(np.array(gt))

        if args['save_results']:
            check_mkdir(save_dir + data_list[idx].split(' ')[0][:-10])
            # BUGFIX: out_file was never closed; use a context manager so
            # the .lines.txt output is flushed per image.
            with open(
                    os.path.join(
                        save_dir,
                        data_list[idx].split(' ')[0][1:-4] + '.lines.txt'),
                    'w') as out_file:
                prob2lines(scoremap, out_file)

    acc, acc_cls, mean_iu, fwavacc = evaluation(prediction_all, gt_all,
                                                culane.num_classes)
    log = 'val results: loss %.5f  acc %.5f  acc_cls %.5f  mean_iu %.5f  fwavacc %.5f' % \
              (loss_record.avg, acc, acc_cls, mean_iu, fwavacc)
    print(log)
    # BUGFIX: close the log file instead of leaking the handle.
    with open(log_path, 'w') as log_file:
        log_file.write(log + '\n')
Beispiel #3
0
def test():
    """Play one episode with an ensemble of the best saved policies.

    Scans ``./policy_grad`` for ``*.pth`` files whose trailing
    ``_<score>`` filename component encodes the training score, loads
    the three highest-scoring checkpoints, and picks each action by
    majority vote over the models' policy argmaxes while rendering.
    """
    # Prepare env
    env = create_env()
    h, w, c = env.observation_space.shape

    # Load the top-3 scoring checkpoints.  (An earlier comment claimed
    # "5 best", but only 3 were ever loaded; names now match the code.)
    device = torch.device("cpu")
    model_dir = "./policy_grad"
    scores = {}
    for fn in os.listdir(model_dir):
        if fn.endswith('.pth'):
            # Filename pattern: ..._<score>.pth -> parse the score.
            score = fn.split("_")[-1][:-4]
            scores[fn] = float(score)
    top_models = heapq.nlargest(3, scores, key=scores.get)

    models = []
    for fn in top_models:
        path = os.path.join(model_dir, fn)
        model = Baseline(h, w).to(device)
        model.load_state_dict(torch.load(path, map_location='cpu'))
        model.eval()
        models.append(model)

    # Watch race car perform
    state = env.reset().transpose((2, 0, 1))  # HWC -> CHW
    state = torch.tensor([state], dtype=torch.float, device=device)
    total_reward = 0
    for t in count():
        # Majority vote over each model's greedy action.
        votes = []
        for model in models:
            pi, _ = model(state)
            votes.append(pi.argmax().item())
        action_idx = Counter(votes).most_common(1)[0][0]
        action = index_to_action(action_idx)
        state, reward, done, _ = env.step(action)
        env.render()

        # Re-encode the next observation and accumulate reward.
        state = state.transpose((2, 0, 1))
        state = torch.tensor([state], dtype=torch.float, device=device)
        total_reward += reward
        if done:
            break
    print("Total reward: {}".format(total_reward))
Beispiel #4
0
def main():
    """Train (or resume training) the CULane Baseline model.

    Builds the net under ``DataParallelWithCallback``, an SGD optimizer
    with a doubled learning rate for bias parameters, optionally
    restores net+optimizer state from ``args['checkpoint']``, then hands
    off to ``train``.
    """
    net = Baseline(num_classes=culane.num_classes,
                   deep_base=args['deep_base']).cuda().train()
    net = DataParallelWithCallback(net)

    # Bias parameters get 2x the base learning rate; everything else 1x.
    bias_params = [
        param for name, param in net.named_parameters()
        if name.endswith('bias')
    ]
    other_params = [
        param for name, param in net.named_parameters()
        if not name.endswith('bias')
    ]
    optimizer = optim.SGD(
        [
            {'params': bias_params, 'lr': 2 * args['base_lr']},
            {'params': other_params, 'lr': args['base_lr']},
        ],
        momentum=args['momentum'])

    if len(args['checkpoint']) > 0:
        print('training resumes from \'%s\'' % args['checkpoint'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint_optim.pth')))
        # Re-apply the 2x/1x split after loading the optimizer state,
        # which overwrote the per-group learning rates.
        optimizer.param_groups[0]['lr'] = 2 * args['base_lr']
        optimizer.param_groups[1]['lr'] = args['base_lr']

    check_mkdir(os.path.join(ckpt_path, exp_name))
    # BUGFIX: the log file handle was leaked; close it deterministically.
    with open(log_path, 'w') as f:
        f.write(str(args) + '\n\n')

    train(net, optimizer)
def main_test():
    """Run inference on the test split and write ``test_output.csv``.

    Requires a trained checkpoint in ``args.resume`` (aborts otherwise).
    Per video: stack features, forward through the model, flatten clip
    outputs to per-frame rows, optionally interpolate up to 6 rows per
    second, truncate to the real frame count, then emit one CSV row per
    frame timestamp.  Videos listed in ``missing`` (features
    unavailable) are padded with all-zero predictions so the CSV covers
    every entry of ``args.test_vidmap``.
    """
    print('Running test...')
    torch.multiprocessing.set_sharing_strategy('file_system')
    model = Baseline()
    if args.use_swa:
        # Wrap with AveragedModel so SWA checkpoint key names line up.
        model = torch.optim.swa_utils.AveragedModel(model)
    model = torch.nn.DataParallel(model).cuda()
    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        print('Loaded ckpt at epoch:', args.start_epoch)
    else:
        print('No model given. Abort!')
        exit(1)

    test_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=None,
            vidmap_path=args.test_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='test',
            test_freq=args.test_freq
        ),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )

    model.eval()
    batch_time = AverageMeter()

    t_start = time.time()

    outputs = []
    with torch.no_grad():
        for i, (img_feat, au_feat, frame_count, vid) in enumerate(test_loader):
            img_feat = torch.stack(img_feat).cuda()
            au_feat = torch.stack(au_feat).cuda()
            assert len(au_feat.size()) == 3, 'bad auf %s' % (vid)
            output = model(img_feat, au_feat) # [Clip S 15]
            # rearrange and remove extra padding in the end
            output = rearrange(output, 'Clip S C -> (Clip S) C')
            output = torch.cat([output, output[-1:]]) # repeat the last frame to avoid missing
            if args.train_freq < args.test_freq:
                # print('interpolating:', output.size()[0], frame_count)
                # NOTE(review): target frequency is hard-coded to 6 here,
                # not args.test_freq — confirm they are meant to agree.
                output = interpolate_output(output, args.train_freq, 6)
            # print('Interpolated:', output.size()[0], frame_count)
            # truncate extra frames
            assert output.size(0) >= frame_count, '{}/{}'.format(output.size(0), frame_count)
            output = output[:frame_count]
            outputs.append((vid, frame_count, output.cpu().detach().numpy()))

            # update statistics
            batch_time.update(time.time() - t_start)
            t_start = time.time()

            if i % args.print_freq == 0:
                output = ('Test: [{0}/{1}]\t'
                          'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    i, len(test_loader), batch_time=batch_time))
                print(output)

    # Six evenly spaced timestamps within each one-second step.
    # NOTE(review): time_step is 1e6 per second, which suggests
    # microseconds, while the CSV header says milliseconds — verify.
    time_stamps = [0, 166666, 333333, 500000, 666666, 833333]
    time_step = 1000000 # time starts at 0
    header = 'Video ID,Timestamp (milliseconds),amusement,anger,awe,concentration,confusion,contempt,contentment,disappointment,doubt,elation,interest,pain,sadness,surprise,triumph\n'

    final_res = {}
    for vid, frame_count, out in outputs:# videos
        video_time = frame_count // 6 + 1
        # print('video', vid, video_time)
        entry_count = 0
        for t in range(video_time): # seconds
            for i in range(6): # frames
                timestamp = time_step * t + time_stamps[i]
                fcc = t * 6 + i
                if fcc >= frame_count:
                    continue
                # print('Frame count', frame_count)
                frame_output = out[fcc]
                frame_output = [str(x) for x in frame_output]
                temp = '{vid},{timestamp},'.format(vid=vid,timestamp=timestamp) + ','.join(frame_output) + '\n'
                # file.write(temp)
                if vid in final_res:
                    final_res[vid].append(temp)
                else:
                    final_res[vid] = [temp]
                entry_count += 1
        assert entry_count == frame_count
    # fixed for now
    # Videos with missing features: emit all-zero rows of the given
    # length so every vidmap entry has output.
    missing = [('WKXrnB7alT8', 2919), ('o0ooW14pIa4', 3733), ('GufMoL_MuNE',2038), ('Uee0Tv1rTz8', 1316), ('ScvvOWtb04Q', 152), ('R9kJlLungmo', 3609),('QMW3GuohzzE', 822), ('fjJYTW2n6rk', 4108), ('rbTIMt0VcLw', 1084),('L9cdaj74kLo', 3678), ('l-ka23gU4NA', 1759)]
    for vid, length in missing:
        video_time = length // 6 + 1
        # print('video', vid, video_time)
        for t in range(video_time): # seconds
            for i in range(6): # frames
                timestamp = time_step * t + time_stamps[i]
                fcc = t * 6 + i
                if fcc >= length:
                    continue
                # 15 zero scores, one per emotion class.
                frame_output = ',0'*15
                temp = '{vid},{timestamp}'.format(vid=vid, timestamp=timestamp) + frame_output + '\n'
                # file.write(temp)
                if vid in final_res:
                    final_res[vid].append(temp)
                else:
                    final_res[vid] = [temp]
    print('Write test outputs...')
    with open('test_output.csv', 'w') as file:
        file.write(header)
        # Emit rows in the vidmap's order so the submission is sorted
        # the same way as the official video list.
        temp_vidmap = [x.strip().split(' ') for x in open(args.test_vidmap)]
        temp_vidmap = [x[0] for x in temp_vidmap]
        for vid in tqdm(temp_vidmap):
            for entry in final_res[vid]:
                file.write(entry)
def main_train(config, checkpoint_dir=None):
    """Ray Tune trainable: one train/validate run with ``config`` hparams.

    ``config`` keys used: 'optimizer' ('adam' | 'adamw'), 'lr',
    'batch_size'.  Remaining settings come from the module-level
    ``args``.  Metrics are reported to Tune each epoch via
    ``tune.report``; checkpoints (a ``(model_state, optimizer_state)``
    tuple) are written through ``tune.checkpoint_dir``.
    ``checkpoint_dir`` is supplied by Tune when restoring a trial.
    """
    global args, best_corr
    best_corr = 0.0

    # Run name: <model>_<MM-DD_HH-MM-SS> (kept for parity with the
    # non-Tune driver; log folders themselves are disabled below).
    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    # check_rootfolders(args)
    if args.model == 'Baseline':
        model = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=(2048 + 128), num_classes=15, kernel_size=11)

    model = torch.nn.DataParallel(model).cuda()

    # NOTE(review): if config['optimizer'] is neither 'adam' nor 'adamw'
    # and args.use_sam is off, `optimizer` is never bound — verify the
    # Tune search space only emits these two values.
    if config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])

    # custom optimizer (replaces the one selected above when enabled)
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=config['lr'])
    # custom lr scheduler
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.cos_wr_t0,T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cos_t_max)
    # SWA: keep a running average of weights, stepped after swa_start.
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=config['lr'])

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    # if args.resume and os.path.isfile(args.resume):
    #     print('Load checkpoint:', args.resume)
    #     ckpt = torch.load(args.resume)
    #     args.start_epoch = ckpt['epoch']
    #     best_corr = ckpt['best_corr']
    #     model.load_state_dict(ckpt['state_dict'])
    #     optimizer.load_state_dict(ckpt['optimizer'])
    #     print('Loaded ckpt at epoch:', args.start_epoch)
    if checkpoint_dir:
        # Tune restore path: checkpoint holds (model_state, optimizer_state).
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)


    # initialize datasets
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train', lpfilter=args.lp_filter
        ),
        batch_size=config['batch_size'], shuffle=True,
        num_workers=args.workers, pin_memory=False,
        drop_last=True
    )

    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val'
        ),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )

    # Validation metric: correlation.
    accuracy = correlation
    # with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
    #     f.write(str(args))

    # tb_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        # train
        train(train_loader, model, optimizer, epoch, None, None)
        # do lr scheduling after epoch; SWA takes over once swa_start is reached
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()

        # validate (on the SWA-averaged model once it is active)
        if args.use_swa and epoch >= args.swa_start:
            # validate use swa model
            corr, loss = validate(val_loader, swa_model, accuracy, epoch, None, None)
        else:
            corr, loss = validate(val_loader, model, accuracy, epoch, None, None)
        is_best = corr > best_corr
        best_corr = max(corr, best_corr)
        # tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
        # output_best = 'Best corr: %.4f\n' % (best_corr)
        # print(output_best)
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     'state_dict': model.state_dict(),
        #     'optimizer': optimizer.state_dict(),
        #     'best_corr': best_corr,
        # }, is_best)
        # Hand the checkpoint to Tune; best epochs get a separate name.
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            if is_best:
                path = os.path.join(checkpoint_dir, "checkpoint_best")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=loss, accuracy=corr, best_corr=best_corr)
Beispiel #7
0
def main_train():
    """Full training driver: build model/optimizer/scheduler, train, validate.

    Model choice, feature sizes, optimizer extras (SAM), LR schedules
    (cosine warm-restart / cosine / multistep), SWA, resume, and logging
    are all controlled by the module-level ``args``.  Validation runs
    every ``args.eval_freq`` epochs (skipping the first two), and the
    best-correlation checkpoint is kept via ``save_checkpoint``.

    NOTE(review): ``best_corr`` is read below without being initialised
    in this function — it must be set at module level (or via resume)
    or the first comparison raises NameError; verify.
    """
    global args, best_corr

    # Run name: <model>_<MM-DD_HH-MM-SS>
    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0

    if not args.val_only:
        check_rootfolders(args)
    # Model selection; Baseline can optionally be restricted to a subset
    # of emotion classes via args.cls_indices.
    if args.model == 'Baseline':
        if args.cls_indices:
            model = Baseline(args.img_feat_size,
                             args.au_feat_size,
                             num_classes=len(args.cls_indices))
        else:
            print('Feature size:', args.img_feat_size, args.au_feat_size)
            model = Baseline(args.img_feat_size, args.au_feat_size)
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96],
                      in_channels=(128),
                      num_classes=15,
                      kernel_size=11)
    elif args.model == 'BaseAu':
        model = Baseline_Au(args.au_feat_size)
    elif args.model == 'BaseImg':
        model = Baseline_Img(args.img_feat_size)
    elif args.model == 'EmoBase':
        model = EmoBase()

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    # custom optimizer (replaces the Adam above when enabled)
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=args.learning_rate)
    # custom lr scheduler (at most one of the three is used)
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.cos_t_max)
    elif args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)
    # SWA: keep a running weight average, stepped after swa_start.
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                                    swa_lr=args.learning_rate)

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)

    # initialize datasets
    train_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.train_csv,
        vidmap_path=args.train_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='train',
        lpfilter=args.lp_filter,
        train_freq=args.train_freq,
        val_freq=args.val_freq,
        cls_indices=args.cls_indices),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.val_csv,
        vidmap_path=args.val_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='val',
        train_freq=args.train_freq,
        val_freq=args.val_freq,
        cls_indices=args.cls_indices,
        repeat_sample=args.repeat_sample),
                                             batch_size=None,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)

    # Validation metric: correlation.
    accuracy = correlation

    if args.val_only:
        # Validation-only mode: run one pass and exit.
        print('Run validation ...')
        print('start epoch:', args.start_epoch, 'model:', args.resume)
        validate(val_loader, model, accuracy, args.start_epoch, None, None)
        return

    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))

    tb_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, optimizer, epoch, log_training, tb_writer)
        # do lr scheduling after epoch; SWA takes over once swa_start is reached
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr or args.use_cos or args.use_multistep:
            scheduler.step()

        # Validate every eval_freq epochs (and on the last), skipping
        # the first two epochs.
        if (epoch + 1) > 2 and ((epoch + 1) % args.eval_freq == 0 or
                                (epoch + 1) == args.epochs):
            # validate (on the SWA-averaged model once it is active)
            if args.use_swa and epoch >= args.swa_start:
                # validate use swa model
                corr = validate(val_loader, swa_model, accuracy, epoch,
                                log_training, tb_writer)
            else:
                corr = validate(val_loader, model, accuracy, epoch,
                                log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % (best_corr)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': best_corr,
                }, is_best)
Beispiel #8
0
def prepare(args):
    """Set up everything needed for a Veri776 training run.

    Builds the Baseline model, augmented train / plain test loaders
    (PK sampling for training), an Adam optimizer with per-parameter
    weight-decay settings, a LambdaLR scheduler and the CUDA loss
    functions.  When ``args.resume_from_checkpoint`` is set and a
    checkpoint file exists, model/optimizer/scheduler state is restored
    and training resumes at the following epoch.

    Returns a dict with keys: start_epoch, model, train_loader,
    test_loader, optimizer, scheduler, losses.
    """
    resume = args.resume_from_checkpoint

    t_begin = time.time()
    logger.info('global', 'Start preparing.')
    check_config_dir()
    logger.info('setting', config_info(), time_report=False)

    model = Baseline(num_classes=Config.nr_class)
    logger.info('setting', model_summary(model), time_report=False)
    logger.info('setting', str(model), time_report=False)

    # Train-time augmentation: jitter -> flip -> pad+crop -> erase -> normalize.
    train_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.3,
                                    contrast=0.3,
                                    saturation=0.3,
                                    hue=0)],
            p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.Pad(10),
        transforms.RandomCrop(Config.input_shape),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Deterministic eval pipeline: resize + normalize only.
    test_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainset = Veri776_train(transforms=train_transforms, need_attr=True)
    testset = Veri776_test(transforms=test_transforms, need_attr=True)

    # PK sampling (P identities x K samples) for the triplet loss.
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=Config.batch_size,
        sampler=PKSampler(trainset, p=Config.P, k=Config.K),
        num_workers=Config.nr_worker,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=Config.batch_size,
        sampler=torch.utils.data.SequentialSampler(testset),
        num_workers=Config.nr_worker,
        pin_memory=True)

    optimizer = torch.optim.Adam(parm_list_with_Wdecay(model), lr=Config.lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_multi_func)

    # ID / attribute cross-entropies plus a hard triplet loss, all on GPU.
    losses = {
        'cross_entropy_loss': torch.nn.CrossEntropyLoss(),
        'type_ce_loss': torch.nn.CrossEntropyLoss(),
        'color_ce_loss': torch.nn.CrossEntropyLoss(),
        'triplet_hard_loss': triplet_hard_loss(margin=Config.triplet_margin),
    }
    for name in losses:
        losses[name] = losses[name].cuda()

    start_epoch = 0
    if resume and os.path.exists(Config.checkpoint_path):
        checkpoint = load_checkpoint()
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # continue training for next the epoch of the checkpoint, or simply start from 1
    start_epoch += 1

    result = {
        'start_epoch': start_epoch,
        'model': model,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'losses': losses,
    }

    t_end = time.time()
    mins, secs = sec2min_sec(t_begin, t_end)
    logger.info(
        'global',
        'Finish preparing, time spend: {}mins {}s.'.format(mins, secs))

    return result