def test(args, meta_model, optimizer, test_loader, train_epoch,
         return_val=False, inner_steps=10, seed=0):
    ''' Meta-Testing '''
    mode = 'Test'
    test_graph_id_local = 0
    test_graph_id_global = 0
    args.resplit = False
    epoch = 0
    args.final_test = False
    inner_test_auc_array = None
    inner_test_ap_array = None
    if return_val:
        args.inner_steps = inner_steps
        args.final_test = True
        inner_test_auc_array = np.zeros((len(test_loader) * args.test_batch_size,
                                         int(1000 / 5)))
        inner_test_ap_array = np.zeros((len(test_loader) * args.test_batch_size,
                                        int(1000 / 5)))
    meta_loss = torch.Tensor([0])
    test_avg_auc_list, test_avg_ap_list = [], []
    test_inner_avg_auc_list, test_inner_avg_ap_list = [], []

    for j, data in enumerate(test_loader):
        if args.adamic_adar_baseline:
            # Val Ratio is Fixed at 0.1
            meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio
            data = meta_model.split_edges(data[0], val_ratio=args.meta_val_edge_ratio,
                                          test_ratio=meta_test_edge_ratio)
            G_test = create_nx_graph(data)
            auc, ap = calc_adamic_adar_score(G_test, data.test_pos_edge_index,
                                             data.test_neg_edge_index)
            test_avg_auc_list.append(auc)
            test_avg_ap_list.append(ap)
            test_graph_id_global += 1
            continue

        if args.deepwalk_baseline or args.deepwalk_and_mlp:
            # Val Ratio is Fixed at 0.2
            meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio
            data = meta_model.split_edges(data[0], val_ratio=args.meta_val_edge_ratio,
                                          test_ratio=meta_test_edge_ratio)
            G = create_nx_graph_deepwalk(data)
            node_vectors, entity2index, index2entity = train_deepwalk_model(G, seed=seed)
            if args.deepwalk_and_mlp:
                early_stopping = EarlyStopping(patience=args.patience, verbose=False)
                input_dim = args.num_features + node_vectors.shape[1]
                mlp = MLPEncoder(args, input_dim, args.num_channels).to(args.dev)
                mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=args.mlp_lr)
                # node1 = data.x[torch.tensor(list(entity2index.keys())).long()]
                all_node_list = list(range(0, len(data.x)))
                node_order = [entity2index[node_i] for node_i in all_node_list]
                node1 = torch.tensor(node_vectors[node_order])
                for mlp_epochs in range(0, args.epochs):
                    mlp_optimizer.zero_grad()
                    # node_inp = torch.cat([torch.tensor(node_vectors), node1], dim=1)
                    node_inp = torch.cat([data.x, node1], dim=1)
                    node_inp = node_inp.to(args.dev)
                    z = mlp(node_inp, edge_index=None)
                    loss = meta_model.recon_loss(z, data.train_pos_edge_index.cuda())
                    loss.backward()
                    mlp_optimizer.step()
                    if mlp_epochs % 10 == 0:
                        if mlp_epochs % 50 == 0:
                            print("Epoch %d, Loss: %f" % (mlp_epochs, loss))
                        with torch.no_grad():
                            val_auc, val_ap = meta_model.test(z, data.val_pos_edge_index,
                                                              data.val_neg_edge_index)
                        early_stopping(val_auc, meta_model)
                        if early_stopping.early_stop:
                            print("Early stopping for Graph %d | AUC: %f AP: %f"
                                  % (test_graph_id_global, val_auc, val_ap))
                            break
                node_inp = torch.cat([data.x, node1], dim=1)
                # node_inp = torch.cat([torch.tensor(node_vectors), node1], dim=1)
                node_inp = node_inp.to(args.dev)
                node_vectors = mlp(node_inp, edge_index=None)
                # evaluate with the freshly encoded embeddings; the original passed
                # the stale `z` from the last training step here, leaving the
                # recomputed `node_vectors` unused
                auc, ap = meta_model.test(node_vectors, data.test_pos_edge_index,
                                          data.test_neg_edge_index)
            else:
                node_vectors = node_vectors.detach().cpu().numpy()
                auc, ap = calc_deepwalk_score(data.test_pos_edge_index,
                                              data.test_neg_edge_index,
                                              node_vectors, entity2index)
            print("Graph %d | Test AUC: %f AP: %f" % (test_graph_id_global, auc, ap))
            test_avg_auc_list.append(auc)
            test_avg_ap_list.append(ap)
            test_graph_id_global += 1
            continue

        if not args.random_baseline and not args.adamic_adar_baseline:
            test_graph_id_local, meta_loss, test_inner_avg_auc_list, test_inner_avg_ap_list = \
                meta_gradient_step(meta_model, args, data, optimizer, args.inner_steps,
                                   args.inner_lr, args.order, test_graph_id_local, mode,
                                   test_inner_avg_auc_list, test_inner_avg_ap_list,
                                   epoch, j, False,
                                   inner_test_auc_array, inner_test_ap_array)

        auc_list, ap_list = global_test(args, meta_model, data,
                                        OrderedDict(meta_model.named_parameters()))
        test_avg_auc_list.append(sum(auc_list) / len(auc_list))
        test_avg_ap_list.append(sum(ap_list) / len(ap_list))

        ''' Test Logging '''
        if args.comet:
            if len(auc_list) > 0 and len(ap_list) > 0:
                auc_metric = 'Test_Outer_Batch_Graph_' + str(j) + '_AUC'
                ap_metric = 'Test_Outer_Batch_Graph_' + str(j) + '_AP'
                args.experiment.log_metric(auc_metric, sum(auc_list) / len(auc_list),
                                           step=train_epoch)
                args.experiment.log_metric(ap_metric, sum(ap_list) / len(ap_list),
                                           step=train_epoch)
        if args.wandb:
            if len(auc_list) > 0 and len(ap_list) > 0:
                auc_metric = 'Test_Outer_Batch_Graph_' + str(j) + '_AUC'
                ap_metric = 'Test_Outer_Batch_Graph_' + str(j) + '_AP'
                wandb.log({auc_metric: sum(auc_list) / len(auc_list),
                           ap_metric: sum(ap_list) / len(ap_list),
                           "x": epoch}, commit=False)

    print("Failed on %d graphs" % (args.fail_counter))
    print("Epoch: %d | Test Global Avg Auc %f | Test Global Avg AP %f"
          % (train_epoch, sum(test_avg_auc_list) / len(test_avg_auc_list),
             sum(test_avg_ap_list) / len(test_avg_ap_list)))
    if args.comet:
        # guard on the aggregate lists; `auc_list`/`ap_list` are per-batch and may
        # be undefined when every graph took a baseline path
        if len(test_avg_auc_list) > 0 and len(test_avg_ap_list) > 0:
            auc_metric = 'Test_Avg' + '_AUC'
            ap_metric = 'Test_Avg' + '_AP'
            args.experiment.log_metric(auc_metric,
                                       sum(test_avg_auc_list) / len(test_avg_auc_list),
                                       step=train_epoch)
            args.experiment.log_metric(ap_metric,
                                       sum(test_avg_ap_list) / len(test_avg_ap_list),
                                       step=train_epoch)
        if len(test_inner_avg_auc_list) > 0 and len(test_inner_avg_ap_list) > 0:
            inner_auc_metric = 'Test_Inner_Avg' + '_AUC'
            inner_ap_metric = 'Test_Inner_Avg' + '_AP'
            args.experiment.log_metric(inner_auc_metric,
                                       sum(test_inner_avg_auc_list) / len(test_inner_avg_auc_list),
                                       step=train_epoch)
            args.experiment.log_metric(inner_ap_metric,
                                       sum(test_inner_avg_ap_list) / len(test_inner_avg_ap_list),
                                       step=train_epoch)
    if args.wandb:
        if len(test_avg_auc_list) > 0 and len(test_avg_ap_list) > 0:
            auc_metric = 'Test_Avg' + '_AUC'
            ap_metric = 'Test_Avg' + '_AP'
            wandb.log({auc_metric: sum(test_avg_auc_list) / len(test_avg_auc_list),
                       ap_metric: sum(test_avg_ap_list) / len(test_avg_ap_list),
                       "x": train_epoch}, commit=False)
        if len(test_inner_avg_auc_list) > 0 and len(test_inner_avg_ap_list) > 0:
            inner_auc_metric = 'Test_Inner_Avg' + '_AUC'
            inner_ap_metric = 'Test_Inner_Avg' + '_AP'
            wandb.log({inner_auc_metric: sum(test_inner_avg_auc_list) / len(test_inner_avg_auc_list),
                       inner_ap_metric: sum(test_inner_avg_ap_list) / len(test_inner_avg_ap_list),
                       "x": train_epoch}, commit=False)

    if len(test_inner_avg_ap_list) > 0:
        print('Epoch {:01d} | Test Inner AUC: {:.4f}, AP: {:.4f}'.format(
            train_epoch,
            sum(test_inner_avg_auc_list) / len(test_inner_avg_auc_list),
            sum(test_inner_avg_ap_list) / len(test_inner_avg_ap_list)))

    if return_val:
        test_avg_auc = sum(test_avg_auc_list) / len(test_avg_auc_list)
        test_avg_ap = sum(test_avg_ap_list) / len(test_avg_ap_list)
        if len(test_inner_avg_ap_list) > 0:
            test_inner_avg_auc = sum(test_inner_avg_auc_list) / len(test_inner_avg_auc_list)
            test_inner_avg_ap = sum(test_inner_avg_ap_list) / len(test_inner_avg_ap_list)

        # Remove all-zero rows (graphs that were skipped or failed)
        test_auc_array = inner_test_auc_array[~np.all(inner_test_auc_array == 0, axis=1)]
        test_ap_array = inner_test_ap_array[~np.all(inner_test_ap_array == 0, axis=1)]
        test_aggr_auc = np.sum(test_auc_array, axis=0) / len(test_loader)
        test_aggr_ap = np.sum(test_ap_array, axis=0) / len(test_loader)
        max_auc = np.max(test_aggr_auc)
        max_ap = np.max(test_aggr_ap)
        auc_metric = 'Test_Complete' + '_AUC'
        ap_metric = 'Test_Complete' + '_AP'
        for val_idx in range(0, test_auc_array.shape[1]):
            auc = test_aggr_auc[val_idx]
            ap = test_aggr_ap[val_idx]
            if args.comet:
                args.experiment.log_metric(auc_metric, auc, step=val_idx)
                args.experiment.log_metric(ap_metric, ap, step=val_idx)
            if args.wandb:
                wandb.log({auc_metric: auc, ap_metric: ap, "x": val_idx})
        print("Test Max AUC: %f | Test Max AP: %f" % (max_auc, max_ap))

        ''' Save local final params '''
        if not os.path.exists('../saved_models/'):
            os.makedirs('../saved_models/')
        save_path = '../saved_models/' + args.namestr + '_local.pt'
        torch.save(meta_model.state_dict(), save_path)
        return max_auc, max_ap
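# The meta-testing baselines above rely on helpers such as `calc_adamic_adar_score`
# that are defined elsewhere in the repo. Below is a minimal sketch of what such a
# scorer could look like, assuming a networkx graph and [2, E] PyTorch edge-index
# tensors; the function body is illustrative, not the repo's actual implementation.

import networkx as nx
import numpy as np
import torch
from sklearn.metrics import average_precision_score, roc_auc_score

def calc_adamic_adar_score_sketch(G, pos_edge_index, neg_edge_index):
    """Score positive/negative edges with the Adamic-Adar index; report AUC / AP."""
    def scores(edge_index):
        pairs = [tuple(p) for p in edge_index.t().tolist()]
        # nx.adamic_adar_index yields (u, v, score) triples for each queried pair
        return [s for _, _, s in nx.adamic_adar_index(G, pairs)]

    pos_scores, neg_scores = scores(pos_edge_index), scores(neg_edge_index)
    y_true = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    y_score = np.concatenate([pos_scores, neg_scores])
    return roc_auc_score(y_true, y_score), average_precision_score(y_true, y_score)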
LOG.info(model_triplet)
pytorch_total_params = sum(p.numel() for p in model_triplet.parameters() if p.requires_grad)
LOG.info("number of parameters in the model: {}".format(pytorch_total_params))
model_triplet.train()
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5, verbose=True)
LOG.info(optimizer)
model_triplet = to_cuda_if_available(model_triplet)

# ##########
# Callbacks
# ##########
if cfg.save_best:
    save_best_call = SaveBest(val_comp="sup")
if cfg.early_stopping is not None:
    early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup")

# ##########
# Training
# ##########
save_results = pd.DataFrame()
model_name_triplet = base_model_name + "triplet"
if cfg.save_best:
    model_path_pretrain = os.path.join(model_directory, model_name_triplet, "best_model")
else:
    model_path_pretrain = os.path.join(model_directory, model_name_triplet,
                                       "epoch_" + str(f_args.epochs))
print("path of model : " + model_path_pretrain)
create_folder(os.path.join(model_directory, model_name_triplet))
def meta_gradient_step(model, args, data_batch, optimiser, inner_train_steps, inner_lr,
                       order, graph_id, mode, inner_avg_auc_list, inner_avg_ap_list,
                       epoch, batch_id, train,
                       inner_test_auc_array=None, inner_test_ap_array=None):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser used to apply the meta-gradients after the inner update
        data_batch: Input graphs for all few-shot tasks
        inner_train_steps: Number of gradient steps used to fit the fast weights
            during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        order: Whether to use 1st-order MAML (update meta-learner weights with the
            gradients of the updated weights on the query set) or 2nd-order MAML
            (differentiate through the gradients of the updated weights on the query
            set with respect to the original weights)
        graph_id: ID of the graph currently being trained
        train: Whether to update the meta-learner weights at the end of the episode
        inner_test_auc_array: Final test AUC array when training to convergence
        inner_test_ap_array: Final test AP array when training to convergence
    """
    task_gradients = []
    task_losses = []
    task_predictions = []
    auc_list = []
    ap_list = []
    torch.autograd.set_detect_anomaly(True)
    for idx, data_graph in enumerate(data_batch):
        data_graph.train_mask = data_graph.val_mask = data_graph.test_mask = data_graph.y = None
        data_graph.batch = None
        num_nodes = data_graph.num_nodes
        if args.use_fixed_feats:
            perm = torch.randperm(args.feats.size(0))
            perm_idx = perm[:num_nodes]
            data_graph.x = args.feats[perm_idx]
        elif args.use_same_fixed_feats:
            node_feats = args.feats[0].unsqueeze(0).repeat(num_nodes, 1)
            data_graph.x = node_feats
        if args.concat_fixed_feats:
            if data_graph.x.shape[1] < args.num_features:
                concat_feats = torch.randn(num_nodes, args.num_concat_features,
                                           requires_grad=False)
                data_graph.x = torch.cat((data_graph.x, concat_feats), 1)

        # Val Ratio is Fixed at 0.1
        meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio

        ''' Check if data is already split '''
        try:
            x, train_pos_edge_index = data_graph.x.to(args.dev), \
                data_graph.train_pos_edge_index.to(args.dev)
            data = data_graph
        except:
            data_graph.x.cuda()
            data = model.split_edges(data_graph, val_ratio=args.meta_val_edge_ratio,
                                     test_ratio=meta_test_edge_ratio)

        # Additional failure checks for small graphs
        if data.val_pos_edge_index.size()[1] == 0 or data.test_pos_edge_index.size()[1] == 0:
            args.fail_counter += 1
            print("Failed on Graph %d" % (graph_id))
            continue

        try:
            x, train_pos_edge_index = data.x.to(args.dev), \
                data.train_pos_edge_index.to(args.dev)
            test_pos_edge_index, test_neg_edge_index = data.test_pos_edge_index.to(args.dev), \
                data.test_neg_edge_index.to(args.dev)
        except:
            print("Failed splitting data on Graph %d" % (graph_id))
            continue

        data_shape = x.shape[2:]
        create_graph = (True if order == 2 else False) and train

        # Create a fast model using the current meta-model weights
        fast_weights = OrderedDict(model.named_parameters())
        early_stopping = EarlyStopping(patience=args.patience, verbose=False)

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            z = model.encode(x, train_pos_edge_index, fast_weights, inner_loop=True)
            loss = model.recon_loss(z, train_pos_edge_index)
            if args.model in ['VGAE']:
                kl_loss = args.kl_anneal * (1 / num_nodes) * model.kl_loss()
                loss = loss + kl_loss
                # print("Inner KL Loss: %f" % (kl_loss.item()))
            if not args.train_only_gs:
                gradients = torch.autograd.grad(loss, fast_weights.values(),
                                                allow_unused=args.allow_unused,
                                                create_graph=create_graph)
                # unused parameters get a zero gradient instead of None
                gradients = [0 if grad is None else grad for grad in gradients]
            if args.wandb:
                wandb.log({"Inner_Train_loss": loss.item()})
            if args.clip_grad:
                # for grad in gradients:
                custom_clip_grad_norm_(gradients, args.clip)
                grad_norm = monitor_grad_norm_2(gradients)
                if args.wandb:
                    inner_grad_norm_metric = 'Inner_Grad_Norm'
                    wandb.log({inner_grad_norm_metric: grad_norm})

            ''' Only do this if it is the final test-set evaluation '''
            if args.final_test and inner_batch % 5 == 0:
                inner_test_auc, inner_test_ap = test(model, x, train_pos_edge_index,
                                                     data.test_pos_edge_index,
                                                     data.test_neg_edge_index, fast_weights)
                val_pos_edge_index = data.val_pos_edge_index.to(args.dev)
                val_loss = val(model, args, x, val_pos_edge_index, data.num_nodes,
                               fast_weights)
                early_stopping(val_loss, model)
                my_step = int(inner_batch / 5)
                inner_test_auc_array[graph_id][my_step] = inner_test_auc
                inner_test_ap_array[graph_id][my_step] = inner_test_ap

            # Update weights manually
            if not args.train_only_gs and args.clip_weight:
                fast_weights = OrderedDict(
                    (name, torch.clamp((param - inner_lr * grad),
                                       -args.clip_weight_val, args.clip_weight_val))
                    for ((name, param), grad) in zip(fast_weights.items(), gradients))
            elif not args.train_only_gs:
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), gradients))

            if early_stopping.early_stop:
                print("Early stopping for Graph %d | AUC: %f AP: %f"
                      % (graph_id, inner_test_auc, inner_test_ap))
                # fill the rest of the row from the current inner step onwards;
                # indexing by inner step keeps this consistent with the writes above
                my_step = int(inner_batch / 5)
                inner_test_auc_array[graph_id][my_step:, ] = inner_test_auc
                inner_test_ap_array[graph_id][my_step:, ] = inner_test_ap
                break

        # Do a pass of the model on the validation data from the current task
        val_pos_edge_index = data.val_pos_edge_index.to(args.dev)
        z_val = model.encode(x, val_pos_edge_index, fast_weights, inner_loop=False)
        loss_val = model.recon_loss(z_val, val_pos_edge_index)
        if args.model in ['VGAE']:
            kl_loss = args.kl_anneal * (1 / num_nodes) * model.kl_loss()
            # print("Outer KL Loss: %f" % (kl_loss.item()))
            loss_val = loss_val + kl_loss
        if args.wandb:
            wandb.log({"Inner_Val_loss": loss_val.item()})
        # print("Inner Val Loss %f" % (loss_val.item()))

        # TODO: Is this backward call needed here? Not sure, but the original repo has it:
        # https://github.com/oscarknagg/few-shot/blob/master/few_shot/maml.py#L84
        if args.extra_backward:
            loss_val.backward(retain_graph=True)

        # Get post-update accuracies
        auc, ap = test(model, x, train_pos_edge_index, data.test_pos_edge_index,
                       data.test_neg_edge_index, fast_weights)
        auc_list.append(auc)
        ap_list.append(ap)
        inner_avg_auc_list.append(auc)
        inner_avg_ap_list.append(ap)

        # Accumulate losses and gradients
        graph_id += 1
        task_losses.append(loss_val)
        if order == 1:
            gradients = torch.autograd.grad(loss_val, fast_weights.values(),
                                            create_graph=create_graph)
            named_grads = {name: g for ((name, _), g)
                           in zip(fast_weights.items(), gradients)}
            task_gradients.append(named_grads)

    if len(auc_list) > 0 and len(ap_list) > 0 and batch_id % 5 == 0:
        print('Epoch {:01d} Inner Graph Batch: {:01d}, Inner-Update AUC: {:.4f}, AP: {:.4f}'
              .format(epoch, batch_id, sum(auc_list) / len(auc_list),
                      sum(ap_list) / len(ap_list)))

    if args.comet:
        if len(ap_list) > 0:
            auc_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AUC'
            ap_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AP'
            avg_auc_metric = mode + '_Inner_Batch_Graph' + '_AUC'
            avg_ap_metric = mode + '_Inner_Batch_Graph' + '_AP'
            args.experiment.log_metric(auc_metric, sum(auc_list) / len(auc_list), step=epoch)
            args.experiment.log_metric(ap_metric, sum(ap_list) / len(ap_list), step=epoch)
            args.experiment.log_metric(avg_auc_metric, sum(auc_list) / len(auc_list), step=epoch)
            args.experiment.log_metric(avg_ap_metric, sum(ap_list) / len(ap_list), step=epoch)
    if args.wandb:
        if len(ap_list) > 0:
            auc_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AUC'
            ap_metric = mode + '_Local_Batch_Graph_' + str(batch_id) + '_AP'
            avg_auc_metric = mode + '_Inner_Batch_Graph' + '_AUC'
            avg_ap_metric = mode + '_Inner_Batch_Graph' + '_AP'
            wandb.log({auc_metric: sum(auc_list) / len(auc_list),
                       ap_metric: sum(ap_list) / len(ap_list),
                       avg_auc_metric: sum(auc_list) / len(auc_list),
                       avg_ap_metric: sum(ap_list) / len(ap_list)})

    meta_batch_loss = torch.Tensor([0])
    if order == 1:
        if train and len(task_losses) != 0:
            sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(param.register_hook(replace_grad(sum_task_gradients, name)))
            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create a `loss` variable;
            # hooks replace the dummy gradients with the mean task gradients.
            # TODO: double check whether the functional forward is really needed here
            z_dummy = model.encode(torch.zeros(x.shape[0], x.shape[1]).float().cuda(),
                                   torch.zeros(train_pos_edge_index.shape[0],
                                               train_pos_edge_index.shape[1]).long().cuda(),
                                   fast_weights)
            loss = model.recon_loss(z_dummy,
                                    torch.zeros(train_pos_edge_index.shape[0],
                                                train_pos_edge_index.shape[1]).long().cuda())
            loss.backward()
            optimiser.step()
            for h in hooks:
                h.remove()
            meta_batch_loss = torch.stack(task_losses).mean()
        return graph_id, meta_batch_loss, inner_avg_auc_list, inner_avg_ap_list
    elif order == 2:
        if len(task_losses) != 0:
            model.train()
            optimiser.zero_grad()
            meta_batch_loss = torch.stack(task_losses).mean()
            if train:
                meta_batch_loss.backward()
                if args.clip_grad:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                    grad_norm = monitor_grad_norm(model)
                    if args.wandb:
                        outer_grad_norm_metric = 'Outer_Grad_Norm'
                        wandb.log({outer_grad_norm_metric: grad_norm})
                optimiser.step()
                if args.clip_weight:
                    for p in model.parameters():
                        p.data.clamp_(-args.clip_weight_val, args.clip_weight_val)
        return graph_id, meta_batch_loss, inner_avg_auc_list, inner_avg_ap_list
    else:
        raise ValueError('Order must be either 1 or 2.')
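# For reference, the core of the inner loop above is the classic MAML "fast weights"
# update: gradients are taken with respect to a functional copy of the parameters,
# and create_graph=True (second order) keeps the update itself differentiable.
# A minimal self-contained sketch with a toy linear model; all names here are
# illustrative and not part of the repo above.

from collections import OrderedDict
import torch
import torch.nn.functional as F

def inner_adapt_sketch(model, loss_fn, x, y, inner_lr=0.01, steps=5, second_order=True):
    """Return task-adapted fast weights without touching model.parameters()."""
    fast_weights = OrderedDict(model.named_parameters())
    for _ in range(steps):
        # functional forward pass with the current fast weights
        pred = F.linear(x, fast_weights['weight'], fast_weights['bias'])
        loss = loss_fn(pred, y)
        grads = torch.autograd.grad(loss, fast_weights.values(),
                                    create_graph=second_order)
        # manual SGD step on the fast weights
        fast_weights = OrderedDict(
            (name, param - inner_lr * grad)
            for (name, param), grad in zip(fast_weights.items(), grads))
    return fast_weights

# Usage: adapt a single nn.Linear on one task's support set. A query-set loss
# computed with `fw` can then be backpropagated into model.parameters().
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
fw = inner_adapt_sketch(model, F.mse_loss, x, y)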
def train_classifier(train_loader, classif_model, optimizer_classif, many_hot_encoder=None,
                     valid_loader=None, state={}, dir_model="model", result_path="res",
                     recompute=True):
    criterion_bce = nn.BCELoss()
    classif_model, criterion_bce = to_cuda_if_available(classif_model, criterion_bce)
    print(classif_model)

    early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup",
                                        init_patience=cfg.first_early_wait)
    save_best_call = SaveBest(val_comp="sup")

    # scheduler = ReduceLROnPlateau(optimizer_classif, 'max', factor=0.1,
    #                               patience=cfg.reduce_lr, verbose=True)
    print(optimizer_classif)

    save_results = pd.DataFrame()

    create_folder(dir_model)
    if cfg.save_best:
        model_path_sup1 = os.path.join(dir_model, "best_model")
    else:
        model_path_sup1 = os.path.join(dir_model, "epoch_" + str(cfg.n_epoch_classifier))
    print("path of model : " + model_path_sup1)

    state['many_hot_encoder'] = many_hot_encoder.state_dict()

    if not os.path.exists(model_path_sup1) or recompute:
        for epoch_ in range(cfg.n_epoch_classifier):
            print(classif_model.training)
            start = time.time()
            loss_mean_bce = []
            for i, samples in enumerate(train_loader):
                inputs, pred_labels = samples
                if i == 0:
                    LOG.debug("classif input shape: {}".format(inputs.shape))

                # zero the parameter gradients
                optimizer_classif.zero_grad()
                inputs = to_cuda_if_available(inputs)

                # forward + backward + optimize
                weak_out = classif_model(inputs)
                weak_out = to_cpu(weak_out)
                # print(output)
                loss_bce = criterion_bce(weak_out, pred_labels)
                loss_mean_bce.append(loss_bce.item())
                loss_bce.backward()
                optimizer_classif.step()

            loss_mean_bce = np.mean(loss_mean_bce)
            classif_model.eval()
            n_class = len(many_hot_encoder.labels)
            macro_f_measure_train = get_f_measure_by_class(classif_model, n_class,
                                                           train_loader)
            if valid_loader is not None:
                macro_f_measure = get_f_measure_by_class(classif_model, n_class,
                                                         valid_loader)
                mean_macro_f_measure = np.mean(macro_f_measure)
            else:
                # guard: `macro_f_measure` is referenced in `results` below and
                # was previously undefined when no validation loader was given
                macro_f_measure = np.array([])
                mean_macro_f_measure = -1
            classif_model.train()

            print("Time to train an epoch: {}".format(time.time() - start))
            # print statistics
            print('[%d / %d, %5d] loss: %.3f' %
                  (epoch_ + 1, cfg.n_epoch_classifier, i + 1, loss_mean_bce))

            results = {"train_loss": loss_mean_bce,
                       "macro_measure_train": np.mean(macro_f_measure_train),
                       "class_macro_train": np.array_str(macro_f_measure_train, precision=2),
                       "macro_measure_valid": mean_macro_f_measure,
                       "class_macro_valid": np.array_str(macro_f_measure, precision=2),
                       }
            for key in results:
                LOG.info("\t\t ----> {} : {}".format(key, results[key]))

            save_results = save_results.append(results, ignore_index=True)
            # scheduler.step(mean_macro_f_measure)

            # ##########
            # Callbacks
            # ##########
            state['epoch'] = epoch_ + 1
            state["model"]["state_dict"] = classif_model.state_dict()
            state["optimizer"]["state_dict"] = optimizer_classif.state_dict()
            state["loss"] = loss_mean_bce
            state.update(results)

            if cfg.early_stopping is not None:
                if early_stopping_call.apply(mean_macro_f_measure):
                    print("EARLY STOPPING")
                    break

            if cfg.save_best and save_best_call.apply(mean_macro_f_measure):
                save_model(state, model_path_sup1)

        if cfg.save_best:
            LOG.info("best model at epoch : {} with macro {}".format(
                save_best_call.best_epoch, save_best_call.best_val))
            LOG.info("loading model from: {}".format(model_path_sup1))
            classif_model, state = load_model(model_path_sup1, return_optimizer=False,
                                              return_state=True)
        else:
            model_path_sup1 = os.path.join(dir_model,
                                           "epoch_" + str(cfg.n_epoch_classifier))
            save_model(state, model_path_sup1)
        LOG.debug("model path: {}".format(model_path_sup1))
        LOG.debug('Finished Training')
    else:
        classif_model, state = load_model(model_path_sup1, return_optimizer=False,
                                          return_state=True)

    LOG.info("#### End classif")
    save_results.to_csv(result_path, sep="\t", header=True, index=False)
    return classif_model, state
def gat_run(data, gpu=-1, filename='result/gat_run.txt', num_out_heads=1, num_layers=1,
            num_hidden=8, num_heads=8, epochs=200, lr=0.005, weight_decay=5e-4,
            in_drop=.6, attn_drop=.6, negative_slope=.2, residual=False,
            early_stop=False, fastmode=False):
    """
    INPUT:
        data : list : see the unpack step below
        gpu : int : index of the CUDA device; -1 means CPU
    OUTPUT:
        None (results are written to `filename`)
    """
    from nets.gats import GAT

    file = open(filename, 'w')
    if gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % gpu)

    # Unpack the data
    n_classes, n_edges, g, num_feats = data
    g = g.int().to(device)
    features = g.ndata['feat'].to(device)
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # create model
    # a list of num_layers copies of num_heads, with num_out_heads for the output layer
    heads = ([num_heads] * num_layers) + [num_out_heads]
    model = GAT(g, num_layers, num_feats, num_hidden, n_classes, heads, F.elu,
                in_drop, attn_drop, negative_slope, residual)
    print(model, file=file)
    if early_stop:
        stopper = EarlyStopping(patience=100)
    model = model.to(device)
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # initialize graph
    dur = []
    for epoch in range(epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        train_acc = accuracy(logits[train_mask], labels[train_mask])

        if fastmode:
            val_acc = accuracy(logits[val_mask], labels[val_mask])
            _, indices = torch.max(logits[val_mask], dim=1)
            f1 = f1_score(labels[val_mask].cpu().numpy(), indices.cpu(), average='micro')
        else:
            val_acc = evaluate(model, features, labels, val_mask)
            _, indices = torch.max(logits[val_mask], dim=1)
            f1 = f1_score(labels[val_mask].cpu().numpy(), indices.cpu(), average='micro')

        if early_stop:
            if stopper.step(val_acc, model):
                break

        gpu_mem_alloc = torch.cuda.max_memory_allocated() / 1000000 \
            if torch.cuda.is_available() else 0
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
              " ValAcc {:.4f} | f1-score {:.4f} | ETputs(KTEPS) {:.2f} | GPU : {:.1f} MB "
              .format(epoch, np.mean(dur), loss.item(), train_acc, val_acc, f1,
                      n_edges / np.mean(dur) / 1000, gpu_mem_alloc), file=file)

    print()
    if early_stop:
        # EarlyStopping couples runs through this shared checkpoint file
        model.load_state_dict(torch.load('es_checkpoint.pt'))
    acc = evaluate(model, features, labels, test_mask)
    with torch.no_grad():
        # recompute logits with the (possibly reloaded) model; the last
        # training-loop logits would otherwise be stale here
        logits = model(features)
    _, indices = torch.max(logits[test_mask], dim=1)
    f1 = f1_score(labels[test_mask].cpu().numpy(), indices.cpu(), average='micro')
    print("Test Accuracy {:.4f} | F1-score {:.4f}".format(acc, f1), file=file)
    file.close()
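# `accuracy` and `evaluate` used by gat_run are assumed helpers in the style of the
# standard DGL GAT example. A plausible minimal version, under the assumption that
# `model(features)` returns per-node logits:

def accuracy(logits, labels):
    """Fraction of nodes whose argmax prediction matches the label."""
    _, indices = torch.max(logits, dim=1)
    correct = torch.sum(indices == labels)
    return correct.item() * 1.0 / len(labels)

def evaluate(model, features, labels, mask):
    """Masked accuracy computed in eval mode without tracking gradients."""
    model.eval()
    with torch.no_grad():
        logits = model(features)
        return accuracy(logits[mask], labels[mask])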
def main(args):
    args['device'] = torch.device("cpu")
    set_random_seed(args['random_seed'])

    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    args['n_tasks'] = dataset.n_tasks
    model = load_model(args)
    loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
                                       reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    epochx = 0
    losses = []
    for epoch in range(args['num_epochs']):
        # Train
        loss = run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
        losses.append(loss)

        # Validation and early stop
        epochx += 1
        val_score = run_an_eval_epoch(args, model, val_loader, epochx, False)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'], val_score,
            args['metric_name'], stopper.best_score))
        if early_stop:
            break

    stopper.load_checkpoint(model)

    # Print out the test set score
    test_score = run_an_eval_epoch(args, model, test_loader, epochx, True)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))

    # Make the loss-per-epoch figure
    # print('losses', len(losses))
    print(losses)
    epoch_list = [i + 1 for i in range(len(losses))]
    plt.clf()
    plt.plot(epoch_list, losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.rcParams['axes.facecolor'] = 'white'
    plt.savefig("Loss.Per.Epoch.png")
        'state_dict': model.state_dict()
    },
}
optimizer, state = get_optimizer(model, state)
criterion_bce = torch.nn.NLLLoss()  # torch.nn.BCELoss()
model, criterion_bce = to_cuda_if_available(model, criterion_bce)
LOG.info(model)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
LOG.info("number of parameters in the model: {}".format(pytorch_total_params))

early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup",
                                    init_patience=cfg.first_early_wait)
save_best_call = SaveBest(val_comp="sup")
print(optimizer)

save_results = pd.DataFrame()

model_name_sup = osp.join(model_directory, "classif")
create_folder(model_name_sup)
if cfg.save_best:
    model_path_sup1 = os.path.join(model_name_sup, "best_model")
else:
    model_path_sup1 = os.path.join(model_name_sup, "epoch_" + str(n_epochs))
print("path of model : " + model_path_sup1)
def train_crossEntropy():
    num_epochs = 800
    with open(main_path + '/results.txt', 'w', 1) as output_file:
        mainModel_stopping = EarlyStopping(patience=300, verbose=True,
                                           log_path=main_path, output_file=output_file)
        classifier_stopping = EarlyStopping(patience=300, verbose=False,
                                            log_path=classifier_path,
                                            output_file=output_file)
        print('*****', file=output_file)
        print('BASELINE', file=output_file)
        print('transfer - augmentation on both waves and specs - 3 channels',
              file=output_file)
        if config.ESC_10:
            print('ESC_10', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.ESC_50:
            print('ESC_50', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.US8K:
            print('US8K', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.us8k_train_folds, config.us8k_test_fold), file=output_file)
        print('number of freq masks are {} and their max length is {}'.format(
            config.freq_masks, config.freq_masks_width), file=output_file)
        print('number of time masks are {} and their max length is {}'.format(
            config.time_masks, config.time_masks_width), file=output_file)
        print('*****', file=output_file)

        for epoch in range(num_epochs):
            model.train()
            classifier.train()
            train_loss = []
            train_corrects = 0
            train_samples_count = 0
            for _, x, label in train_loader:
                loss = 0
                optimizer.zero_grad()
                inp = x.float().to(device)
                label = label.to(device).unsqueeze(1)
                label_vec = hotEncoder(label)
                y_rep = model(inp)
                y_rep = F.normalize(y_rep, dim=0)
                y_pred = classifier(y_rep)
                loss += cross_entropy_one_hot(y_pred, label_vec)
                loss.backward()
                train_loss.append(loss.item())
                optimizer.step()

                train_corrects += (torch.argmax(y_pred, dim=1) ==
                                   torch.argmax(label_vec, dim=1)).sum().item()
                train_samples_count += x.shape[0]

            val_loss = []
            val_corrects = 0
            val_samples_count = 0
            model.eval()
            classifier.eval()
            with torch.no_grad():
                for _, val_x, val_label in val_loader:
                    inp = val_x.float().to(device)
                    label = val_label.to(device)
                    label_vec = hotEncoder(label)
                    y_rep = model(inp)
                    y_rep = F.normalize(y_rep, dim=0)
                    y_pred = classifier(y_rep)
                    temp = cross_entropy_one_hot(y_pred, label_vec)
                    val_loss.append(temp.item())
                    val_corrects += (torch.argmax(y_pred, dim=1) ==
                                     torch.argmax(label_vec, dim=1)).sum().item()
                    val_samples_count += val_x.shape[0]

            scheduler.step()
            train_acc = train_corrects / train_samples_count
            val_acc = val_corrects / val_samples_count
            print('\n', file=output_file)
            print("Epoch: {}/{}...".format(epoch + 1, num_epochs),
                  "Loss: {:.4f}...".format(np.mean(train_loss)),
                  "Val Loss: {:.4f}".format(np.mean(val_loss)), file=output_file)
            print('train_acc is {:.4f} and val_acc is {:.4f}'.format(train_acc, val_acc),
                  file=output_file)

            mainModel_stopping(-val_acc, main_model, epoch + 1)
            classifier_stopping(-val_acc, classifier, epoch + 1)
            if mainModel_stopping.early_stop:
                print("Early stopping", file=output_file)
                return
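# `hotEncoder` and `cross_entropy_one_hot` are helpers defined elsewhere in this
# project. A plausible minimal version consistent with how they are called above;
# `n_classes` and `device` are assumed module-level globals:

def hotEncoder(v):
    """Turn a (B, 1) tensor of class indices into (B, n_classes) one-hot rows."""
    ret_vec = torch.zeros(v.shape[0], n_classes).to(device)
    for s in range(v.shape[0]):
        ret_vec[s][v[s]] = 1
    return ret_vec

def cross_entropy_one_hot(input, target):
    """Cross-entropy between logits and one-hot targets via argmax recovery."""
    _, labels = target.max(dim=1)
    return torch.nn.CrossEntropyLoss()(input, labels)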
def train_contrastive():
    num_epochs = 800
    with open(main_path + '/results.txt', 'w', 1) as output_file:
        mainModel_stopping = EarlyStopping(patience=300, verbose=True,
                                           log_path=main_path, output_file=output_file)
        print('*****', file=output_file)
        print('Supervised Contrastive Loss', file=output_file)
        print('temperature for the contrastive loss is {}'.format(config.temperature),
              file=output_file)
        if config.ESC_10:
            print('ESC_10', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.ESC_50:
            print('ESC_50', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.US8K:
            print('US8K', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        print('number of freq masks are {} and their max length is {}'.format(
            config.freq_masks, config.freq_masks_width), file=output_file)
        print('number of time masks are {} and their max length is {}'.format(
            config.time_masks, config.time_masks_width), file=output_file)
        print('*****', file=output_file)

        for epoch in range(num_epochs):
            model.train()
            projection_head.train()
            train_loss = []
            for _, x, label in train_loader:
                batch_loss = 0
                optimizer.zero_grad()
                x = x.to(device)
                label = label.to(device).unsqueeze(1)
                label_vec = hotEncoder(label)
                y_rep = model(x.float())
                y_rep = F.normalize(y_rep, dim=0)
                y_proj = projection_head(y_rep)
                y_proj = F.normalize(y_proj, dim=0)
                batch_loss = loss_fn(y_proj.unsqueeze(1), label.squeeze(1))
                batch_loss.backward()
                train_loss.append(batch_loss.item())
                optimizer.step()

            val_loss = []
            model.eval()
            projection_head.eval()
            with torch.no_grad():
                for _, val_x, val_label in val_loader:
                    val_x = val_x.to(device)
                    label = val_label.to(device).unsqueeze(1)
                    label_vec = hotEncoder(label)
                    y_rep = model(val_x.float())
                    y_rep = F.normalize(y_rep, dim=0)
                    y_proj = projection_head(y_rep)
                    y_proj = F.normalize(y_proj, dim=0)
                    temp = loss_fn(y_proj.unsqueeze(1), label.squeeze(1))
                    val_loss.append(temp.item())

            scheduler.step()
            print("Epoch: {}/{}...".format(epoch + 1, num_epochs),
                  "Loss: {:.4f}...".format(np.mean(train_loss)),
                  "Val Loss: {:.4f}".format(np.mean(val_loss)), file=output_file)

            mainModel_stopping(np.mean(val_loss), model, epoch + 1)
            if mainModel_stopping.early_stop:
                print("Early stopping", file=output_file)
                return
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

train_dataset = GarmentDataset(args.rgb_dir, args.data_dir, args.data_type,
                               args.train_label_dir)
test_dataset = GarmentDataset(args.rgb_dir, args.data_dir, args.data_type,
                              args.test_label_dir)
validate_dataset = GarmentDataset(args.rgb_dir, args.data_dir, args.data_type,
                                  args.validate_label_dir)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                          num_workers=16)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True,
                         num_workers=16)
validate_loader = DataLoader(validate_dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=16)

vae = ConvVAE(img_channels=utils.DATA_SIZE, latent_size=utils.LATENT_SIZE).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                 patience=5)
earlystopping = EarlyStopping('min', patience=30)

checkpoint_count = len(os.listdir(args.checkpoint_dir))
reload_dir = os.path.join(args.checkpoint_dir, utils.BEST_FILENAME)
if args.generate_sample == 0 and args.reload == 1 and os.path.exists(reload_dir):
    best_state = torch.load(reload_dir)
    print('Reloading vae......, file: ', reload_dir)
    vae.load_state_dict(best_state['state_dict'])
    optimizer.load_state_dict(best_state['optimizer_dict'])
    scheduler.load_state_dict(best_state['scheduler_dict'])
    earlystopping.load_state_dict(best_state['earlystopping_dict'])
    # delete the loaded state dict to free GPU memory
    del best_state

# generate result
if args.generate_sample == 1:
def train_classifier():
    num_epochs = 800
    with open(main_path + '/classifier_results.txt', 'w', 1) as output_file:
        classifier_stopping = EarlyStopping(patience=300, verbose=True,
                                            log_path=classifier_path,
                                            output_file=output_file)
        print('*****', file=output_file)
        print('classifier after sup_contrastive', file=output_file)
        if config.ESC_10:
            print('ESC_10', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.ESC_50:
            print('ESC_50', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.US8K:
            print('US8K', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.us8k_train_folds, config.us8k_test_fold), file=output_file)
        print('number of freq masks are {} and their max length is {}'.format(
            config.freq_masks, config.freq_masks_width), file=output_file)
        print('number of time masks are {} and their max length is {}'.format(
            config.time_masks, config.time_masks_width), file=output_file)
        print('*****', file=output_file)

        for epoch in range(num_epochs):
            classifier.train()
            train_loss = []
            train_corrects = 0
            train_samples_count = 0
            for _, x, label in train_loader:
                loss = 0
                optimizer.zero_grad()
                x = x.float().to(device)
                label = label.to(device).unsqueeze(1)
                label_vec = hotEncoder(label)
                y_rep = pretrained_model(x)
                y_rep = F.normalize(y_rep, dim=0)
                out = classifier(y_rep)
                loss = cross_entropy_one_hot(out, label_vec)
                loss.backward()
                train_loss.append(loss.item())
                optimizer.step()

                train_corrects += (torch.argmax(out, dim=1) ==
                                   torch.argmax(label_vec, dim=1)).sum().item()
                train_samples_count += x.shape[0]

            val_loss = []
            val_corrects = 0
            val_samples_count = 0
            classifier.eval()
            with torch.no_grad():
                for _, val_x, val_label in val_loader:
                    val_x = val_x.float().to(device)
                    label = val_label.to(device).unsqueeze(1)
                    label_vec = hotEncoder(label)
                    y_rep = pretrained_model(val_x)
                    y_rep = F.normalize(y_rep, dim=0)
                    out = classifier(y_rep)
                    temp = cross_entropy_one_hot(out, label_vec)
                    val_loss.append(temp.item())
                    val_corrects += (torch.argmax(out, dim=1) ==
                                     torch.argmax(label_vec, dim=1)).sum().item()
                    val_samples_count += val_x.shape[0]

            train_acc = train_corrects / train_samples_count
            val_acc = val_corrects / val_samples_count
            scheduler.step()
            print('\n', file=output_file)
            print("Epoch: {}/{}...".format(epoch + 1, num_epochs),
                  "Loss: {:.4f}...".format(np.mean(train_loss)),
                  "Val Loss: {:.4f}".format(np.mean(val_loss)), file=output_file)
            print('train_acc is {:.4f} and val_acc is {:.4f}'.format(train_acc, val_acc),
                  file=output_file)

            classifier_stopping(-val_acc, classifier, epoch + 1)
            if classifier_stopping.early_stop:
                print("Early stopping", file=output_file)
                return
def train_hybrid():
    num_epochs = 800
    with open(main_path + '/results.txt', 'w', 1) as output_file:
        mainModel_stopping = EarlyStopping(patience=300, verbose=True,
                                           log_path=main_path, output_file=output_file)
        classifier_stopping = EarlyStopping(patience=300, verbose=False,
                                            log_path=classifier_path,
                                            output_file=output_file)
        print('*****', file=output_file)
        print('HYBRID', file=output_file)
        print('alpha is {}'.format(config.alpha), file=output_file)
        print('temperature of contrastive loss is {}'.format(config.temperature),
              file=output_file)
        if config.ESC_10:
            print('ESC_10', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.ESC_50:
            print('ESC_50', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        elif config.US8K:
            print('US8K', file=output_file)
            print('train folds are {} and test fold is {}'.format(
                config.train_folds, config.test_fold), file=output_file)
        print('Freq mask number {} and length {}, and time mask number {} and length is {}'
              .format(config.freq_masks, config.freq_masks_width,
                      config.time_masks, config.time_masks_width), file=output_file)
        print('*****', file=output_file)

        for epoch in range(num_epochs):
            print('\n' + str(optimizer.param_groups[0]["lr"]), file=output_file)
            model.train()
            projection_layer.train()
            classifier.train()
            train_loss = []
            train_loss1 = []
            train_loss2 = []
            train_corrects = 0
            train_samples_count = 0
            for _, x, label in train_loader:
                loss = 0
                optimizer.zero_grad()
                x = x.float().to(device)
                label = label.to(device).unsqueeze(1)
                label_vec = hotEncoder(label)
                y_rep = model(x)
                y_rep = F.normalize(y_rep, dim=0)
                y_proj = projection_layer(y_rep)
                y_proj = F.normalize(y_proj, dim=0)
                y_pred = classifier(y_rep)
                loss1, loss2 = loss_fn(y_proj, y_pred, label, label_vec)
                loss = loss1 + loss2
                torch.autograd.backward([loss1, loss2])
                train_loss.append(loss.item())
                train_loss1.append(loss1.item())
                train_loss2.append(loss2.item())
                optimizer.step()

                train_corrects += (torch.argmax(y_pred, dim=1) ==
                                   torch.argmax(label_vec, dim=1)).sum().item()
                train_samples_count += x.shape[0]

            val_loss = []
            val_loss1 = []
            val_loss2 = []
            val_corrects = 0
            val_samples_count = 0
            model.eval()
            projection_layer.eval()
            classifier.eval()
            with torch.no_grad():
                for _, val_x, val_label in val_loader:
                    val_x = val_x.float().to(device)
                    label = val_label.to(device).unsqueeze(1)
                    label_vec = hotEncoder(label)
                    y_rep = model(val_x)
                    y_rep = F.normalize(y_rep, dim=0)
                    y_proj = projection_layer(y_rep)
                    y_proj = F.normalize(y_proj, dim=0)
                    y_pred = classifier(y_rep)
                    loss1, loss2 = loss_fn(y_proj, y_pred, label, label_vec)
                    loss = loss1 + loss2
                    val_loss.append(loss.item())
                    val_loss1.append(loss1.item())
                    val_loss2.append(loss2.item())
                    val_corrects += (torch.argmax(y_pred, dim=1) ==
                                     torch.argmax(label_vec, dim=1)).sum().item()
                    val_samples_count += val_x.shape[0]

            train_acc = train_corrects / train_samples_count
            val_acc = val_corrects / val_samples_count
            scheduler.step()
            print("\nEpoch: {}/{}...".format(epoch + 1, num_epochs),
                  "Loss: {:.4f}...".format(np.mean(train_loss)),
                  "Val Loss: {:.4f}".format(np.mean(val_loss)), file=output_file)
            print('train_loss1 is {:.4f} and train_loss2 is {:.4f}'.format(
                np.mean(train_loss1), np.mean(train_loss2)), file=output_file)
            print('val_loss1 is {:.4f} and val_loss2 is {:.4f}'.format(
                np.mean(val_loss1), np.mean(val_loss2)), file=output_file)
            print('train_acc is {:.4f} and val_acc is {:.4f}'.format(train_acc, val_acc),
                  file=output_file)

            # add validation checkpoint for early stopping here
            mainModel_stopping(-val_acc, model, epoch + 1)
            # proj_stopping(-val_acc, projection_layer, epoch + 1)
            classifier_stopping(-val_acc, classifier, epoch + 1)
            if mainModel_stopping.early_stop:
                print("Early stopping", file=output_file)
                return
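# `loss_fn` in train_hybrid returns a (contrastive, classification) pair that is
# summed above. A hedged sketch of what such a hybrid criterion could look like,
# assuming a SupConLoss-style contrastive term and the `cross_entropy_one_hot`
# helper used elsewhere in this file; the alpha weighting is an assumption based
# on the config.alpha value printed by train_hybrid:

import torch

class HybridLossSketch(torch.nn.Module):
    """Weighted sum of a supervised-contrastive term and a cross-entropy term."""

    def __init__(self, sup_con_loss, alpha=0.5):
        super().__init__()
        self.sup_con_loss = sup_con_loss  # e.g. SupConLoss(temperature=config.temperature)
        self.alpha = alpha

    def forward(self, y_proj, y_pred, label, label_vec):
        loss1 = self.alpha * self.sup_con_loss(y_proj.unsqueeze(1), label.squeeze(1))
        loss2 = (1 - self.alpha) * cross_entropy_one_hot(y_pred, label_vec)
        return loss1, loss2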
def train(model, optimizer, args):
    ####################################################################################################################
    # declare variables
    ####################################################################################################################
    time_window_ratio = args.time_window_ratio
    sample_count = args.sample_count
    populated = args.populated
    augmented = args.augmented
    device = torch.device(f'cuda:{args.num_gpu}'
                          if not args.no_cuda and torch.cuda.is_available() else 'cpu')

    patience = 20
    early_stopping_upstream = EarlyStopping(patience=patience, verbose=False,
                                            path=args.upstream_model_path)
    early_stopping_downstream = EarlyStopping(patience=patience, verbose=False,
                                              path=args.downstream_model_path)

    # assert parameters
    if populated:
        assert time_window_ratio is not None and sample_count is not None \
               and type(sample_count) == int
    else:
        assert time_window_ratio is None and sample_count is None
    assert not (args.fcn_moco and args.moco_fcn)

    if args.model_type == 'moco_lstm_fcn_aug' or args.model_type == 'moco_tpa_fcn_aug' \
            or args.model_type == 'moco_tpa_fcn' or args.model_type == 'moco_lstm_fcn':
        args.moco_fcn = True
    elif args.model_type == 'tpa_fcn_moco':
        args.fcn_moco = True

    loss_list = [sys.maxsize]
    accuracy_list = []
    test_best_possible, best_so_far = 0.0, sys.maxsize

    path = args.data_path + "raw/" + args.dataset + "/"
    x_train = np.load(path + 'X_train.npy')
    y_train = np.load(path + 'y_train.npy')
    x_test = np.load(path + 'X_test.npy')
    y_test = np.load(path + 'y_test.npy')
    nclass = int(np.amax(y_train)) + 1
    ntimestep = x_train.shape[1]
    nfeature = x_train.shape[2]

    ####################################################################################################################
    # scale dataset
    ####################################################################################################################
    l_rng_train = get_channel_range(x_train)
    l_rng_test = get_channel_range(x_test)
    print(f"train channel range : {l_rng_train}")
    print(f"test channel range : {l_rng_test}")
    if args.scale == 'standard_channel':
        for c in range(nfeature):
            scaler = StandardScaler()
            x_train[:, :, c] = scaler.fit_transform(x_train[:, :, c])
            x_test[:, :, c] = scaler.transform(x_test[:, :, c])
    elif args.scale == 'minmax_channel':
        # assumption: the original condition here was 'minmax_all', duplicating the
        # whole-array branch below and leaving it unreachable; renamed to mirror
        # 'standard_channel'
        for c in range(nfeature):
            scaler = MinMaxScaler()
            x_train[:, :, c] = scaler.fit_transform(x_train[:, :, c])
            x_test[:, :, c] = scaler.transform(x_test[:, :, c])
    elif args.scale == 'standard_all':
        scaler = StandardScaler()
        origin_shape = x_train.shape
        x_train = x_train.reshape(-1, 1)
        x_train = scaler.fit_transform(x_train)
        x_train = x_train.reshape(origin_shape)
        origin_shape = x_test.shape
        x_test = x_test.reshape(-1, 1)
        x_test = scaler.transform(x_test)
        x_test = x_test.reshape(origin_shape)
    elif args.scale == 'minmax_all':
        scaler = MinMaxScaler()
        origin_shape = x_train.shape
        x_train = x_train.reshape(-1, 1)
        x_train = scaler.fit_transform(x_train)
        x_train = x_train.reshape(origin_shape)
        origin_shape = x_test.shape
        x_test = x_test.reshape(-1, 1)
        x_test = scaler.transform(x_test)
        x_test = x_test.reshape(origin_shape)

    l_rng_train = get_channel_range(x_train)
    l_rng_test = get_channel_range(x_test)
    print(f"after train channel range : {l_rng_train}")
    print(f"after test channel range : {l_rng_test}")

    ####################################################################################################################
    # customize dataset
    ####################################################################################################################
    start_time = time.time()
    print(f"create train and test dataset started..")
    if populated:
        train_ds = PopulatedBalancedDataSet(xs=x_train, ys=y_train, n_classes=nclass,
                                            time_window=x_train.shape[1],
                                            time_window_ratio=args.time_window_ratio,
                                            min_sampling_count=args.sample_count)
        test_ds = PopulatedBalancedDataSet(xs=x_test, ys=y_test, n_classes=nclass,
                                           time_window=x_train.shape[1],
                                           time_window_ratio=args.time_window_ratio,
                                           min_sampling_count=args.sample_count)
    else:
        # moved up from the "moco dataset" section below so the loaders are always
        # defined (the original created these only after the DataLoaders, which
        # raised a NameError in the non-populated path)
        train_ds = NonPopulatedDataset(xs=x_train, ys=y_train)
        test_ds = NonPopulatedDataset(xs=x_test, ys=y_test)

    # TODO: also run an ablation later with augmentation applied to the training data
    # if augmented:
    #     if populated:
    #         populated_x_train = np.asarray(train_ds.df['x'].tolist())
    #         populated_y_train = np.asarray(train_ds.df['y'].tolist())
    #         populated_x_test = np.asarray(test_ds.df['x'].tolist())
    #         populated_y_test = np.asarray(test_ds.df['y'].tolist())
    #         train_ds = AugmentedBalanceDataSet(xs=populated_x_train, ys=populated_y_train, args=args, moco=False)
    #         test_ds = NonPopulatedDataset(xs=populated_x_test, ys=populated_y_test)
    #     else:
    #         train_ds = AugmentedBalanceDataSet(xs=x_train, ys=y_train, args=args, moco=False)
    #         test_ds = NonPopulatedDataset(xs=x_test, ys=y_test)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=args.shuffle,
                              drop_last=False, pin_memory=True, num_workers=NUM_WORKER)
    test_loader = DataLoader(test_ds, batch_size=args.test_batch_size, drop_last=False,
                             pin_memory=True, num_workers=NUM_WORKER)
    print(f"create train and test dataset took {time.time()-start_time}..")

    ####################################################################################################################
    # here is for the moco dataset
    ####################################################################################################################
    start_time = time.time()
    print(f"create moco dataset started...")
    if args.moco_fcn:
        assert not (args.positive_1 and args.positive_2)
        if populated:
            prev_moco_ds = PopulatedBalancedDataSet(xs=x_train, ys=y_train,
                                                    n_classes=nclass,
                                                    time_window=x_train.shape[1],
                                                    time_window_ratio=args.time_window_ratio,
                                                    min_sampling_count=args.sample_count)
            print(f'original train dataset_size : {len(x_train)}, test dataset_size : {len(x_test)}')
            print(f'populated train dataset_size : {len(prev_moco_ds.df)}')
            if args.positive_1:
                moco_ds = MocoInterPostiiveSet(prev_moco_ds)
            elif args.positive_2:
                moco_ds = MoCoSameClassPositiveSet(prev_moco_ds)
        if augmented:
            if populated:
                xs = np.asarray(moco_ds.df['data_q'].tolist())  # moco_ds.df['data_q'].to_numpy()
                xs_pair = np.asarray(moco_ds.df['data_k'].tolist())  # moco_ds.df['data_k'].to_numpy()
                populated_y_train = np.asarray(moco_ds.df['num_class'].tolist())  # moco_ds.df['num_class'].to_numpy()
                moco_ds = AugmentedBalanceDataSet(xs=xs, xs_pair=xs_pair,
                                                  ys=populated_y_train, args=args)
            else:
                moco_ds = AugmentedBalanceDataSet(xs=x_train, ys=y_train, args=args)
        print(f"create moco dataset took {time.time() - start_time}..")

        start_time = time.time()
        print(f"upstream task started..")
        moco_loader = DataLoader(moco_ds, batch_size=args.batch_size,
                                 shuffle=args.shuffle, drop_last=False,
                                 pin_memory=True, num_workers=NUM_WORKER)
        recent_losses = deque(maxlen=5)
        for epoch in range(args.moco_pretrain_epochs):
            t = time.time()
            ############################ pre train #############################
            loss_moco_train_list = []
            model.train()
            for i, d in enumerate(moco_loader):
                tq = torch.as_tensor(d['data_q'], dtype=torch.float32).to(device)
                tk = torch.as_tensor(d['data_k'], dtype=torch.float32).to(device)
                ys = torch.as_tensor(d['num_class'], dtype=torch.long).to(device)
                optimizer.zero_grad()
                if type(model) == MoCoFcnMixed:
                    logits, labels, _, _ = model(tq=tq, tk=tk, ys=ys)
                else:
                    logits, labels, _ = model(tq=tq, tk=tk, ys=ys)
                loss_moco = F.cross_entropy(logits, labels)
                loss_moco_train_list.append(loss_moco.item())
                loss_moco.backward()
                optimizer.step()

            if epoch != 0 and epoch % 10 == 0:
                torch.save(model.state_dict(),
                           args.upstream_model_path + f'/checkpoint_{epoch}.pt')
            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_moco: {:.8f}'.format(np.average(loss_moco_train_list)),
                  'time: {:.4f}s'.format(time.time() - t))
            wandb.log({
                "Epoch": epoch + 1,
                "loss_moco": np.average(loss_moco_train_list)
            })
            recent_losses.append(np.average(loss_moco_train_list))
            print("loss moco: " + str(np.average(loss_moco_train_list)))
        print(f"upstream task took {time.time() - start_time}.")

        # do tsne!!!
        if type(model) == MoCoFcnMixed:
            model.eval()
            tsne_embeddings = []
            tsne_ys = []
            for i, d in enumerate(test_loader):  # todo: train_loader for tsne?
                tq = torch.as_tensor(d['x'], dtype=torch.float32).to(device)
                tk = torch.as_tensor(d['x'], dtype=torch.float32).to(device)
                ys = torch.as_tensor(d['y'], dtype=torch.long).to(device)
                _, _, y_pred, embedding_tsne = model(tq=tq, tk=tk, ys=ys)
                tsne_embeddings.append(embedding_tsne.cpu().detach().numpy())
                tsne_ys.append(ys.cpu().detach().numpy())
            tsne_embeddings = np.concatenate(tsne_embeddings, axis=0)
            tsne_ys = np.concatenate(tsne_ys, axis=0)
            pca = PCA(n_components=30)
            data = pca.fit_transform(tsne_embeddings)
            X_embedded = TSNE(n_components=2).fit_transform(data)
            from matplotlib import pyplot as plt
            plt.figure(figsize=(6, 5))
            colors = 'r', 'g', 'b', 'c', 'm', 'y', 'k', 'w', 'orange', 'purple'
            nlables = len(set(tsne_ys))
            labels = range(nlables)
            for i, c, label in zip(labels, colors[:nlables], labels):
                plt.scatter(X_embedded[tsne_ys == i, 0], X_embedded[tsne_ys == i, 1],
                            c=c, label=label)
            plt.legend()
            plt.savefig(args.upstream_model_path +
                        f'{args.project_name}_{args.exp_name}_tsne.png')

        # ############################### fine tuning ###############################
        # ################################## train ##################################
        for epoch in range(args.epochs):
            t = time.time()
            loss_train_list = []
            accuracy_train_list = []
            loss_val_list = []
            accuracy_val_list = []
            model.train()
            for i, d in enumerate(train_loader):
                t = time.time()
                tq = torch.as_tensor(d['x'], dtype=torch.float32).to(device)
                tk = torch.as_tensor(d['x'], dtype=torch.float32).to(device)
                ys = torch.as_tensor(d['y'], dtype=torch.long).to(device)
                if type(model) == MoCoFcnMixed:
                    _, _, y_pred, _ = model(tq=tq, tk=tk, ys=ys)
                else:
                    _, _, y_pred = model(tq=tq, tk=tk, ys=ys)
                optimizer.zero_grad()
                loss_train = F.cross_entropy(y_pred, ys.squeeze())
                loss_train_list.append(loss_train.item())
                acc_train = accuracy(y_pred, ys)
                accuracy_train_list.append(acc_train)
                loss_train.backward()
                optimizer.step()

            # ################################## test ###################################
            model.eval()
            for i, d in enumerate(test_loader):
                t = time.time()
                tq = torch.as_tensor(d['x'], dtype=torch.float32).to(device)
                tk = torch.as_tensor(d['x'], dtype=torch.float32).to(device)
                ys = torch.as_tensor(d['y'], dtype=torch.long).to(device)
                if type(model) == MoCoFcnMixed:
                    _, _, y_pred, embedding_tsne = model(tq=tq, tk=tk, ys=ys)
                else:
                    _, _, y_pred = model(tq=tq, tk=tk, ys=ys)
                loss_val = F.cross_entropy(y_pred, ys.squeeze())
                acc_val = accuracy(y_pred, ys)
                loss_val_list.append(loss_val.item())
                accuracy_val_list.append(acc_val)

            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.8f}'.format(np.average(loss_train_list)),
                  'acc_train: {:.4f}'.format(np.average(accuracy_train_list)),
                  'loss_val: {:.4f}'.format(np.average(loss_val_list)),
                  'acc_val: {:.4f}'.format(np.average(accuracy_val_list)),
                  'time: {:.4f}s'.format(time.time() - t))
            wandb.log({
                "Epoch": epoch + 1,
                "loss_train": np.average(loss_train_list),
                "acc_train": np.average(accuracy_train_list),
                "loss_val": np.average(loss_val_list),
                "acc_val": np.average(accuracy_val_list),
            })

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            early_stopping_downstream(np.average(loss_val_list), model)
            if early_stopping_downstream.early_stop:
                print("Downstream Task Early stopping")
                break

            if np.average(accuracy_val_list) > test_best_possible:
                test_best_possible = np.average(accuracy_val_list)
            if best_so_far > np.average(loss_train_list):
                best_so_far = np.average(loss_train_list)
                test_acc = np.average(accuracy_val_list)

        print("test_acc: " + str(test_acc))
        print("best possible: " + str(test_best_possible))
        wandb.run.summary.update({"accuracy": str(test_acc)})
        wandb.run.summary.update({"best_accuracy": str(test_best_possible)})
        wandb.run.summary.update({"recent_loss_moco": str(sum(loss_moco_train_list))})

    # ############################ fine tuning only (no moco) ############################
    else:
        print(f'original train dataset_size : {len(x_train)}, test dataset_size : {len(x_test)}')
        if populated:
            print(f'populated_dataset_size : {len(train_ds)}')
            dataset_size = len(train_ds)
        for epoch in range(args.epochs):
            print(f'epoch {epoch} started..')
            t = time.time()
            # ################################## train ##################################
            loss_train_list = []
            accuracy_train_list = []
            loss_val_list = []
            accuracy_val_list = []
            model.train()
            for batch_index, sample in enumerate(train_loader):
                xs = sample['x'].float().to(device)
                ys = sample['y'].long().to(device)
                optimizer.zero_grad()
                y_pred = model(xs)
                loss_train = F.cross_entropy(y_pred, ys.squeeze())
                loss_train_list.append(loss_train.item())
                acc_train = accuracy(y_pred, ys)
                accuracy_train_list.append(acc_train)
                loss_train.backward()
                optimizer.step()

            # ################################## test ###################################
            model.eval()
            for batch_index, sample in enumerate(test_loader):
                xs = sample['x'].float().to(device)
                ys = sample['y'].long().to(device)
                y_pred = model(xs)
                loss_val = F.cross_entropy(y_pred, ys.squeeze())
                acc_val = accuracy(y_pred, ys)
                loss_val_list.append(loss_val.item())
                accuracy_val_list.append(acc_val)

            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.8f}'.format(np.average(loss_train_list)),
                  'acc_train: {:.4f}'.format(np.average(accuracy_train_list)),
                  'loss_val: {:.4f}'.format(np.average(loss_val_list)),
                  'acc_val: {:.4f}'.format(np.average(accuracy_val_list)),
                  'time: {:.4f}s'.format(time.time() - t))
            wandb.log({
                "Epoch": epoch + 1,
                "loss_train": np.average(loss_train_list),
                "acc_train": np.average(accuracy_train_list),
                "loss_val": np.average(loss_val_list),
                "acc_val": np.average(accuracy_val_list),
            })
            if np.average(accuracy_val_list) > test_best_possible:
                test_best_possible = np.average(accuracy_val_list)
            if best_so_far > np.average(loss_train_list):
                best_so_far = np.average(loss_train_list)
                test_acc = np.average(accuracy_val_list)

        print("test_acc: " + str(test_acc))
        print("best possible: " + str(test_best_possible))
        wandb.run.summary.update({"best_accuracy": str(test_best_possible)})
        wandb.run.summary.update({"accuracy": str(test_acc)})
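# `get_channel_range` used for the scaling diagnostics above is assumed to report
# per-channel (min, max) over an (N, T, C) array; the prints treat its result as a
# plain list. A sketch under that assumption:

import numpy as np

def get_channel_range_sketch(x):
    """Return [(min, max), ...] for each channel of an (N, T, C) array."""
    return [(float(x[:, :, c].min()), float(x[:, :, c].max()))
            for c in range(x.shape[2])]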
def train(self, train_dataset, epochs, gradient_accumalation_step, dev_dataset=None,
          train_batch_size=16, dev_batch_size=32, num_workers=2, gradient_clipping=5):
    train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size,
                              shuffle=True, num_workers=num_workers)
    early_stopping = EarlyStopping(not_improve_step=3, verbose=True)
    dev_loader = None
    if dev_dataset is not None:
        dev_loader = DataLoader(dataset=dev_dataset, batch_size=dev_batch_size,
                                shuffle=False, num_workers=num_workers)
    self.model.to(self.device)
    global_step = 0
    best_loss = 1000
    for i in range(epochs):
        self.model.train()
        train_epoch_loss = Loss()
        self.metric.clear_memory()
        for batch in tqdm(train_loader):
            global_step += 1
            step_loss, y_true, y_pred = self.iter(batch)
            # if y_true is not None and y_pred is not None:
            #     self.metric.write(y_true, y_pred)
            step_loss = step_loss / gradient_accumalation_step
            step_loss.backward()
            train_epoch_loss.write(step_loss.item())
            self.log.write('training_loss', step_loss.item(), global_step)
            if global_step % gradient_accumalation_step == 0:
                # clip_grad_norm is deprecated; use the in-place variant
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               gradient_clipping)
                self.optimizer.step()
                self.model.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
        train_loss = train_epoch_loss.average()
        # train_result = self.metric.average()
        # for tag, item in train_result.items():
        #     self.log.write(tag, item, i + 1)

        if dev_loader is not None:
            self.model.eval()
            dev_epoch_loss = Loss()
            self.metric.clear_memory()
            with torch.no_grad():
                for batch in tqdm(dev_loader):
                    step_loss, y_true, y_pred = self.iter(batch, is_train=False)
                    # self.metric.write(y_true, y_pred)
                    dev_epoch_loss.write(step_loss.item())
                    self.log.write('validation_loss', step_loss.item(), global_step)
            dev_loss = dev_epoch_loss.average()
            early_stopping.step(val=dev_loss)
            if dev_loss <= best_loss:
                best_loss = dev_loss
                model_path = 'model_epoch_{0}_best_loss{1:.2f}.pth'.format(i + 1, best_loss)
                self.save_model(model_path)
            # dev_result = self.metric.average()
            # for tag, item in dev_result.items():
            #     self.log.write(tag, item, i + 1)
            print('epoch - {0}, global_step:{1}, train_loss:{2:.2f}, dev_loss:{3:.2f}'
                  .format(i + 1, global_step, train_loss, dev_loss))
            # print(train_result, dev_result)
        else:
            if train_loss <= best_loss:
                best_loss = train_loss
                model_path = 'model_epoch_{0}_best_loss{1:.2f}.pth'.format(i + 1, best_loss)
                self.save_model(model_path)
            print('epoch - {0}, global_step:{1}, train_loss:{2:.2f}'.format(
                i + 1, global_step, train_loss))
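# The snippets above use several mutually incompatible EarlyStopping APIs: a
# callable taking (score, model), a .step(val) / .step(score, model) method, and
# constructors with patience / val_comp / not_improve_step / path keywords. For
# reference, here is a minimal patience-based sketch in the spirit of the
# .step(val=...) variant used in the trainer above; it is illustrative and not
# any one of those repos' actual classes.

class EarlyStoppingSketch:
    """Stop when the monitored value has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0
        self.early_stop = False

    def step(self, val):
        # lower is better; reset the counter on any sufficient improvement
        if val < self.best - self.min_delta:
            self.best = val
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

# Usage: call `if stopper.step(dev_loss): break` at the end of each validation pass.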