class Solver(object):

    def __init__(self, config, train_data_loader, test_data_loader,
                 val_data_loader=None):
        # Data loaders
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.test_data_loader = test_data_loader
        self.config = config
        self.build_model()

        # Build tensorboard logger if requested
        if self.config.use_tensorboard:
            self.build_tensorboard()

        # Start from a trained model if one is given
        if self.config.pretrained_model is not None:
            self.load_pretrained_model()

    def load_pretrained_model(self):
        self.protonet.load_state_dict(
            torch.load(
                os.path.join(self.config.model_save_dir,
                             self.config.pretrained_model)))

    def to_var(self, x, volatile=False):
        # Legacy helper: `Variable` and `volatile` have no effect on PyTorch >= 0.4,
        # but are kept here for compatibility with the original code.
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def build_model(self):
        # self.protonet = ProtoNet(x_dim=3, hid_dim=64, out_dim=64)
        self.protonet = ProtoNet_resnet(layers=[3, 4, 6, 3])
        print(self.protonet)

        # Optimizer
        self.optimizer = torch.optim.Adam(self.protonet.parameters(),
                                          self.config.lr)

        # Learning rate scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=self.optimizer,
            gamma=self.config.lr_scheduler_gamma,
            step_size=self.config.lr_scheduler_step)

        if torch.cuda.is_available():
            self.protonet.cuda()

    def build_tensorboard(self):
        from tools.logger import Logger
        self.logger = Logger(self.config.log_dir)

    def train(self):
        train_loss = []
        train_acc = []
        val_loss = []
        val_acc = []
        best_acc = 0
        best_model_path = os.path.join(self.config.model_save_dir,
                                       'best_model.pth')
        last_model_path = os.path.join(self.config.model_save_dir,
                                       'last_model.pth')

        # Start training
        start_time = time.time()
        for e in range(self.config.num_epochs):
            print('=== Epoch: {} ==='.format(e))
            for i, (images, labels) in tqdm(enumerate(self.train_data_loader)):
                self.optimizer.zero_grad()
                images = self.to_var(images)
                labels = self.to_var(labels)
                features = self.protonet(images)
                # Episodic (prototypical) loss over the sampled support/query sets
                loss, acc = self.protonet.loss(
                    samples=features,
                    labels=labels,
                    num_way=self.config.num_train_way,
                    num_support=self.config.num_train_support,
                    num_query=self.config.num_train_query)
                loss.backward()
                self.optimizer.step()
                train_loss.append(loss.item())
                train_acc.append(acc.item())

            train_avg_loss = np.mean(
                train_loss[-self.config.num_train_episodes:])
            train_avg_acc = np.mean(
                train_acc[-self.config.num_train_episodes:])
            print('Avg Train Loss: {}, Avg Train Acc: {}'.format(
                train_avg_loss, train_avg_acc))
            self.lr_scheduler.step()

            # Logging
            log = {}
            log['train_avg_loss'] = train_avg_loss
            log['train_avg_acc'] = train_avg_acc

            if self.val_data_loader is not None:
                # Start validating (validation episodes reuse the training
                # way/support/query configuration)
                for images, labels in self.val_data_loader:
                    images = self.to_var(images)
                    labels = self.to_var(labels)
                    features = self.protonet(images)
                    loss, acc = self.protonet.loss(
                        samples=features,
                        labels=labels,
                        num_way=self.config.num_train_way,
                        num_support=self.config.num_train_support,
                        num_query=self.config.num_train_query)
                    val_loss.append(loss.item())
                    val_acc.append(acc.item())

                val_avg_loss = np.mean(
                    val_loss[-self.config.num_train_episodes:])
                val_avg_acc = np.mean(
                    val_acc[-self.config.num_train_episodes:])
                postfix = ' (Best)' if val_avg_acc >= best_acc \
                    else ' (Best: {})'.format(best_acc)
                print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(
                    val_avg_loss, val_avg_acc, postfix))
                if val_avg_acc >= best_acc:
                    torch.save(self.protonet.state_dict(), best_model_path)
                    best_acc = val_avg_acc
                log['val_avg_loss'] = val_avg_loss
                log['val_avg_acc'] = val_avg_acc

            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
            print('Time:{}'.format(elapsed))

            if self.config.use_tensorboard:
                for tag, value in log.items():
                    self.logger.scalar_summary(tag, value, e + 1)

        torch.save(self.protonet.state_dict(), last_model_path)

    def test(self):
        print('Loading Model......')
        best_model_path = os.path.join(self.config.model_save_dir,
                                       'best_model.pth')
        self.protonet.load_state_dict(torch.load(best_model_path))
        avg_acc = list()
        for e in range(5):
            print('== Epoch:{} =='.format(e))
            for images, labels in self.test_data_loader:
                images = self.to_var(images)
                labels = self.to_var(labels)
                features = self.protonet(images)
                loss, acc = self.protonet.loss(
                    samples=features,
                    labels=labels,
                    num_way=self.config.num_train_way,
                    num_support=self.config.num_train_support,
                    num_query=self.config.num_train_query)
                avg_acc.append(acc.item())
        avg_acc = np.mean(avg_acc)
        print('Test Acc: {}'.format(avg_acc))
        return avg_acc
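

# The Solver above delegates the episodic loss to `self.protonet.loss`, whose
# implementation lives elsewhere in the repo. As a rough illustration only, a
# standard prototypical-network loss for one episode could look like the hedged
# sketch below; the function name `prototypical_loss_sketch` and its exact
# argument handling are assumptions, not the repo's actual code.
def prototypical_loss_sketch(features, labels, num_way, num_support, num_query):
    import torch
    import torch.nn.functional as F
    classes = torch.unique(labels)[:num_way]
    prototypes, query_feats, query_targets = [], [], []
    for t, c in enumerate(classes):
        idx = torch.nonzero(labels == c).view(-1)
        support_idx = idx[:num_support]
        query_idx = idx[num_support:num_support + num_query]
        # Class prototype = mean embedding of the support samples
        prototypes.append(features[support_idx].mean(0))
        query_feats.append(features[query_idx])
        query_targets.append(torch.full((query_idx.size(0),), t,
                                        dtype=torch.long, device=features.device))
    prototypes = torch.stack(prototypes)      # (num_way, feat_dim)
    query_feats = torch.cat(query_feats)      # (num_way * num_query, feat_dim)
    query_targets = torch.cat(query_targets)
    # Negative squared Euclidean distance to each prototype acts as the logit
    dists = torch.cdist(query_feats, prototypes) ** 2
    log_p = F.log_softmax(-dists, dim=1)
    loss = F.nll_loss(log_p, query_targets)
    acc = (log_p.argmax(1) == query_targets).float().mean()
    return loss, acc

# A typical driver (hedged; the exact config fields are assumptions) would be:
#   solver = Solver(config, train_loader, test_loader, val_loader)
#   solver.train()
#   solver.test()
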
def main():
    # opt.manualSeed = random.randint(1, 10000)
    opt.manualSeed = 1
    # random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)

    if opt.tanh:
        opt.logdir = opt.logdir + '_tanh'
    if opt.grid:
        opt.logdir = opt.logdir + '_grid'
    outputdir = os.path.join(
        opt.logdir,
        'gridlen%r_gridnum%d' % (opt.grid_len, opt.grid_num),
        'bs%d_wd%s_lr%r_lamb%r_ratio%r_posi%r_%s' %
        (opt.batch_size, opt.wd, opt.lr, opt.lamb, opt.ratio,
         opt.posi_ratio, opt.optimizer))
    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    LOG_FOUT = open(os.path.join(outputdir, 'log.txt'), 'w')

    def log_string(out_str):
        LOG_FOUT.write(out_str + '\n')
        LOG_FOUT.flush()
        print(out_str)

    log_train = Logger(outputdir)
    tb_log = Logger(outputdir)
    log_string(str(opt) + '\n')

    grid_len = opt.grid_len / 100
    grid_num = opt.grid_num

    dataset = GraspData(opt.dataset_root,
                        sample_ratio=opt.ratio,
                        posi_ratio=opt.posi_ratio,
                        grid_len=grid_len,
                        grid_num=grid_num)
    dataLoader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers,
                                             pin_memory=True)

    net = GraspPoseNet(tanh=opt.tanh, grid=opt.grid, bn=False).cuda()

    lr = opt.lr
    params = []
    for item in net.named_parameters():
        key, value = item[0], item[1]
        if value.requires_grad:
            params += [{'params': [value], 'lr': lr, 'weight_decay': opt.wd}]
    if opt.optimizer == "adam":
        optimizer = torch.optim.Adam(params)
    elif opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(params, momentum=opt.momentum)

    if opt.resume is not None:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_loss = checkpoint['best_loss']
            model_dict = net.state_dict()
            pretrained_dict = checkpoint['state_dict']
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            net.load_state_dict(model_dict)
            optimizer.load_state_dict(checkpoint['optimizer'])
            # lr = checkpoint['lr']
            print("\n=> loaded checkpoint '{}' (epoch {})".format(
                opt.resume, checkpoint['epoch']))
            del checkpoint
        else:
            assert False, 'WRONG RESUME PATH!'
    else:
        start_epoch = opt.start_epoch
        best_loss = 100000.0
        lr = opt.lr

    # criterion = Loss(opt.num_points_mesh, opt.sym_list)
    score_criterion = nn.BCELoss().cuda()
    # Per-element regression loss; newer PyTorch spells this reduction='none'
    reg_criterion = nn.MSELoss(reduce=False).cuda()
    best_test = np.Inf

    st_time = time.time()
    for epoch in range(start_epoch, opt.nepoch + 1):
        # lr = adjust_learning_rate(optimizer, epoch, opt.lr)
        net.train()

        # Running sums for the console log (reset every opt.print_freq iterations)
        loss_sum = 0
        prop_loss_sum = 0
        score_loss_sum = 0
        ang_loss_sum = 0
        off_loss_sum = 0
        anti_acc_sum = 0
        anti_recall_sum = 0
        grasp_acc_sum = 0
        grasp_recall_sum = 0
        prop_acc_sum = 0
        prop_recall_sum = 0

        # Per-epoch accumulators for tensorboard
        loss_epoch = 0
        prop_loss_epoch = 0
        score_loss_epoch = 0
        ang_loss_epoch = 0
        off_loss_epoch = 0
        anti_acc_epoch = 0
        anti_recall_epoch = 0
        grasp_acc_epoch = 0
        grasp_recall_epoch = 0
        prop_acc_epoch = 0
        prop_recall_epoch = 0

        for i, data in enumerate(dataLoader):
            pc_, grids_, contact_, center_, contact_index_, scores_, grasps_idx_, angles_, posi_mask_, \
                angles_scorer_, posi_nega_idx_ = data
            print(grids_.size(), contact_.size(), pc_.size(), angles_.size())
            # Skip degenerate or oversized point clouds
            if contact_index_.size(1) == 1 or pc_.size(1) > 20000 or pc_.size(1) < 10:
                continue
            st = time.time()

            # Due to the limit of GPU memory, we need two GPUs: one for model
            # training (cuda:0), another for grasp proposal generation (cuda:1).
            pc1, grids1, contact_index1, center1, scores1 = pc_.float().cuda(1), grids_.float().cuda(1), \
                contact_index_.long().cuda(1), center_.float().cuda(1), scores_.float().cuda(1)
            pc, grids, angles, contact_index, center, scores, grasps_idx, posi_mask = \
                pc_.float().cuda(0), grids_.float().cuda(0), angles_.float().cuda(0), \
                contact_index_.long().cuda(0), center_.float().cuda(0), \
                scores_.float().cuda(0), grasps_idx_.long().cuda(0), posi_mask_.float().cuda(0)
            data_index = torch.arange(contact_index_.size(1)).long().cuda()
            radius = grid_len / grid_num * np.sqrt(3)

            print('start proposal')
            pairs_all_, scores_all_, offsets_all_, local_points_, data_index_, prop_label_, posi_prop_idx_, \
                nega_prop_idx_, posi_idx_, nega_idx_ = getProposals(pc1, grids1, center1, contact_index1,
                                                                    scores1, data_index, radius=radius)
            del (grids1, center1, contact_index1, scores1)
            pairs_all, scores_all, offsets_all, local_points, data_index, prop_label, posi_prop_idx, nega_prop_idx, \
                posi_idx, nega_idx = pairs_all_.cuda(0), scores_all_.cuda(0), offsets_all_.cuda(0), local_points_.cuda(0), \
                data_index_.cuda(0), prop_label_.cuda(0), posi_prop_idx_.cuda(0), nega_prop_idx_.cuda(0), \
                posi_idx_.cuda(0), nega_idx_.cuda(0)
            print('proposal time: ', time.time() - st, 'posi-nega num: ',
                  posi_idx.size(0), nega_idx.size(0))
            # Skip batches that do not contain both positive and negative samples
            if scores_all.max() == 0 or scores_all.min() > 0 or posi_idx.size(0) == 0 or nega_idx.size(0) == 0 or \
                    nega_prop_idx.size(0) == 0 or posi_prop_idx.size(0) == 0:
                continue

            grasp_center_ = center_[:, posi_nega_idx_[0]].float()
            grasp_contact_ = contact_[:, posi_nega_idx_[0]].float()
            grasp_angle_ = angles_scorer_[:, posi_nega_idx_[0]].float()
            grasp_center1 = grasp_center_.cuda(1)
            grasp_contact1 = grasp_contact_.cuda(1)
            grasp_local_points_ = getLocalPoints(pc1, grasp_contact1, grasp_center1)
            grasp_local_points = grasp_local_points_.cuda(0).long()
            grasp_center, grasp_angle = grasp_center_.cuda(0), grasp_angle_.cuda(0).unsqueeze(-1)
            grasp_label = scores_[:, posi_nega_idx_[0]].float().cuda(0)

            # Free the CPU tensors and the proposal-GPU copies
            del (pc_, grids_, contact_, center_, contact_index_, scores_, grasps_idx_, angles_, local_points_,
                 offsets_all_, scores_all_, pairs_all_, data_index_, posi_prop_idx_, nega_prop_idx_, posi_idx_,
                 nega_idx_, grasp_local_points_, grasp_center1, grasp_contact1)

            prop_score, pred_score, pred_offset, pred_angle, posi_prop_idx, nega_prop_idx, posi_idx, nega_idx \
                = net(pc, local_points, pairs_all, posi_prop_idx, nega_prop_idx,
                      posi_idx, nega_idx, grasp_center, grasp_angle, grasp_local_points)

            prop_label = prop_label[:, torch.cat([posi_prop_idx.view(-1), nega_prop_idx.view(-1)], 0)]
            select_idx = torch.cat([posi_idx.view(-1), nega_idx.view(-1)], 0)
            select_data = data_index[:, select_idx].view(-1)
            gt_score = scores_all[:, select_idx]
            gt_offset = offsets_all[:, select_idx]
            gt_angle = angles[:, select_data].squeeze(0)
            grasps_idx = grasps_idx[:, select_data].squeeze(0)
            posi_mask = posi_mask[:, select_data].squeeze(0)

            prop_acc, prop_recall = cal_accuracy(prop_label, prop_score, recall=True)
            gt_label = (gt_score > 0).float()
            grasp_acc, grasp_recall = cal_accuracy(grasp_label, pred_score, recall=True)

            prop_loss = score_criterion(prop_score, prop_label)
            print('proposal score: ', prop_score.max().item(), prop_score.min().item(),
                  prop_label.max().item(), prop_label.min().item())
            print('grasp score: ', pred_score.max().item(), pred_score.min().item(),
                  gt_score.max().item(), gt_score.min().item())
            score_loss = score_criterion(pred_score, grasp_label)
            print('angle: %6f %6f' % (pred_angle.min().item(), pred_angle.max().item()),
                  'offsets: %6f %6f' % (pred_offset.min().item(), pred_offset.max().item()))

            # Angle and offset regression losses are only evaluated on positive
            # proposals, weighted by the ground-truth grasp score.
            posi_gt = torch.nonzero(gt_score.view(-1)).view(-1)
            posi_score = gt_score[0, posi_gt]
            ang_loss = angle_loss(pred_angle[0][posi_gt].unsqueeze(-1),
                                  gt_angle[posi_gt].unsqueeze(-1),
                                  posi_mask[posi_gt].unsqueeze(-1))
            ang_loss = torch.sum(posi_score * ang_loss) / posi_score.sum()
            off_loss = torch.sum(gt_score * reg_criterion(pred_offset, gt_offset).sum(-1)) / gt_score.sum()
            all_loss = prop_loss + score_loss + opt.lamb * ang_loss + off_loss

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()
            # print(time.time() - st)

            loss_sum += all_loss.item()
            prop_loss_sum += prop_loss.item()
            score_loss_sum += score_loss.item()
            ang_loss_sum += ang_loss.item()
            off_loss_sum += off_loss.item()
            prop_acc_sum += prop_acc
            prop_recall_sum += prop_recall
            grasp_acc_sum += grasp_acc
            grasp_recall_sum += grasp_recall

            loss_epoch += all_loss.item()
            prop_loss_epoch += prop_loss.item()
            score_loss_epoch += score_loss.item()
            ang_loss_epoch += ang_loss.item()
            off_loss_epoch += off_loss.item()
            prop_acc_epoch += prop_acc
            prop_recall_epoch += prop_recall
            grasp_acc_epoch += grasp_acc
            grasp_recall_epoch += grasp_recall

            del (all_loss, prop_loss, score_loss, ang_loss, off_loss)

            if i % opt.print_freq == 0:
                loss_sum /= opt.print_freq
                prop_loss_sum /= opt.print_freq
                score_loss_sum /= opt.print_freq
                ang_loss_sum /= opt.print_freq
                off_loss_sum /= opt.print_freq
                prop_acc_sum /= opt.print_freq
                grasp_acc_sum /= opt.print_freq
                prop_recall_sum /= opt.print_freq
                grasp_recall_sum /= opt.print_freq
                log_string('Epoch: [{0}][{1}/{2}]\t'
                           'all_loss: {Loss:.4f} '
                           'prop_loss: {prop_loss:.4f} '
                           'score_loss: {score_loss:.4f} '
                           'ang_loss: {ang_loss:.4f} '
                           'off_loss: {off_loss:.4f}\t'
                           'prop_acc: {prop_acc:.4f} '
                           'prop_recall: {prop_recall:.4f} '
                           'grasp_acc: {grasp_acc:.4f} '
                           'grasp_recall: {grasp_recall:.4f}\t'
                           'lr: {lr:.5f}\t'.format(
                               epoch, i, len(dataLoader), Loss=loss_sum,
                               prop_loss=prop_loss_sum, score_loss=score_loss_sum,
                               ang_loss=ang_loss_sum, off_loss=off_loss_sum,
                               prop_acc=prop_acc_sum, grasp_acc=grasp_acc_sum,
                               prop_recall=prop_recall_sum,
                               grasp_recall=grasp_recall_sum, lr=lr))
                loss_sum = 0
                prop_loss_sum = 0
                score_loss_sum = 0
                ang_loss_sum = 0
                off_loss_sum = 0
                prop_acc_sum = 0
                grasp_acc_sum = 0
                prop_recall_sum = 0
                grasp_recall_sum = 0

        loss_epoch /= len(dataLoader)
        prop_loss_epoch /= len(dataLoader)
        score_loss_epoch /= len(dataLoader)
        ang_loss_epoch /= len(dataLoader)
        off_loss_epoch /= len(dataLoader)
        prop_acc_epoch /= len(dataLoader)
        grasp_acc_epoch /= len(dataLoader)
        prop_recall_epoch /= len(dataLoader)
        grasp_recall_epoch /= len(dataLoader)

        tb_log.scalar_summary('train_loss/all_loss', loss_epoch, epoch)
        tb_log.scalar_summary('train_loss/prop_loss', prop_loss_epoch, epoch)
        tb_log.scalar_summary('train_loss/score_loss', score_loss_epoch, epoch)
        tb_log.scalar_summary('train_loss/angle_loss', ang_loss_epoch, epoch)
        tb_log.scalar_summary('train_loss/offset_loss', off_loss_epoch, epoch)
        tb_log.scalar_summary('train_acc/prop_accuracy', prop_acc_epoch, epoch)
        tb_log.scalar_summary('train_acc/grasp_accuracy', grasp_acc_epoch, epoch)
        tb_log.scalar_summary('train_acc/prop_recall', prop_recall_epoch, epoch)
        tb_log.scalar_summary('train_acc/grasp_recall', grasp_recall_epoch, epoch)

        best_loss = 1000.0
        is_best = False
        if epoch % 5 == 0:
            checkpoint_dict = {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_loss': best_loss,
                'lr': lr,
                'optimizer': optimizer.state_dict()
            }
            save_checkpoint(checkpoint_dict, is_best, outputdir, epoch)
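

# Standard script entry point; hedged addition, since the snippet above ends at
# main() and the surrounding repo may invoke main() differently.
if __name__ == '__main__':
    main()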